BiliSakura committed on
Commit
a9c9521
·
verified ·
1 Parent(s): 418ab4a

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. JiT-B-16/model_index.json +9 -2
  3. JiT-B-16/pipeline.py +460 -0
  4. JiT-B-16/scheduler/scheduler_config.json +7 -0
  5. JiT-B-16/scheduler/scheduling_jit.py +161 -0
  6. JiT-B-16/transformer/config.json +18 -0
  7. JiT-B-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  8. JiT-B-16/transformer/jit_transformer_2d.py +500 -0
  9. JiT-B-32/model_index.json +9 -2
  10. JiT-B-32/pipeline.py +460 -0
  11. JiT-B-32/scheduler/scheduler_config.json +7 -0
  12. JiT-B-32/scheduler/scheduling_jit.py +161 -0
  13. JiT-B-32/transformer/config.json +18 -0
  14. JiT-B-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  15. JiT-B-32/transformer/jit_transformer_2d.py +500 -0
  16. JiT-H-16/model_index.json +9 -2
  17. JiT-H-16/pipeline.py +460 -0
  18. JiT-H-16/scheduler/scheduler_config.json +7 -0
  19. JiT-H-16/scheduler/scheduling_jit.py +161 -0
  20. JiT-H-16/transformer/config.json +18 -0
  21. JiT-H-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  22. JiT-H-16/transformer/jit_transformer_2d.py +500 -0
  23. JiT-H-32/model_index.json +10 -3
  24. JiT-H-32/pipeline.py +460 -0
  25. JiT-H-32/scheduler/scheduler_config.json +7 -0
  26. JiT-H-32/scheduler/scheduling_jit.py +161 -0
  27. JiT-H-32/transformer/config.json +18 -0
  28. JiT-H-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  29. JiT-H-32/transformer/jit_transformer_2d.py +500 -0
  30. JiT-L-16/model_index.json +9 -2
  31. JiT-L-16/pipeline.py +460 -0
  32. JiT-L-16/scheduler/scheduler_config.json +7 -0
  33. JiT-L-16/scheduler/scheduling_jit.py +161 -0
  34. JiT-L-16/transformer/config.json +18 -0
  35. JiT-L-16/transformer/diffusion_pytorch_model.safetensors +3 -0
  36. JiT-L-16/transformer/jit_transformer_2d.py +500 -0
  37. JiT-L-32/model_index.json +9 -2
  38. JiT-L-32/pipeline.py +460 -0
  39. JiT-L-32/scheduler/scheduler_config.json +7 -0
  40. JiT-L-32/scheduler/scheduling_jit.py +161 -0
  41. JiT-L-32/transformer/config.json +18 -0
  42. JiT-L-32/transformer/diffusion_pytorch_model.safetensors +3 -0
  43. JiT-L-32/transformer/jit_transformer_2d.py +500 -0
  44. README.md +44 -54
  45. demo.png +2 -2
  46. demo_images/jit_h32_final_test.png +3 -0
  47. demo_images/jit_h32_test_inference.png +2 -2
  48. labels/__pycache__/imagenet_labels.cpython-312.pyc +0 -0
  49. labels/id2label_cn.json +1002 -0
  50. labels/id2label_en.json +1002 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  demo.png filter=lfs diff=lfs merge=lfs -text
37
  demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  demo.png filter=lfs diff=lfs merge=lfs -text
37
  demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
38
+ demo_images/jit_h32_final_test.png filter=lfs diff=lfs merge=lfs -text
JiT-B-16/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
  }
JiT-B-16/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Default initial-noise scale keyed by the transformer's `sample_size` (pixels).
# `JiTPipeline.__call__` looks the image size up here when `noise_scale` is not
# given and falls back to 1.0 for sizes that are not listed.
RECOMMENDED_NOISE_BY_SIZE = {
    256: 1.0,
    512: 2.0,
}
33
+
34
+
35
class JiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation using JiT (Just image Transformer).

    Parameters:
        transformer ([`JiTTransformer2DModel`]):
            A class-conditioned `JiTTransformer2DModel` to denoise the images.
        scheduler ([`JiTScheduler`]):
            Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
        id2label_cn (`dict[int, str]`, *optional*):
            ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
    """

    # Component order consulted by diffusers' model CPU offload helpers.
    model_cpu_offload_seq = "transformer"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
        """Load a self-contained variant folder locally or from the Hub.

        The variant folder is expected to contain `transformer/` and `scheduler/`
        sub-folders, each holding its own Python module and config file. Label
        files are looked up in a `labels/` folder next to the variant folder.

        Args:
            pretrained_model_name_or_path: `None`, `""` or `"."` select the folder
                containing this file. A string containing `/` that does not exist on
                disk is treated as a Hub repo id and downloaded with
                `snapshot_download`. Anything else is treated as a local path
                (relative paths are tried against the current directory first, then
                against the folder containing this file).
            subfolder: Optional variant sub-folder inside the repo or local path.
            **kwargs: Forwarded to every component's `from_pretrained`. The special
                key `hub_kwargs` (a dict) is forwarded to `snapshot_download` only.

        Returns:
            A `JiTPipeline` instance.

        Raises:
            ValueError: If no loadable transformer is found under the variant folder.

        Examples:
            JiTPipeline.from_pretrained(".")
            JiTPipeline.from_pretrained("./JiT-H-32")
            DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
        """
        repo_root = Path(__file__).resolve().parent

        # Always strip `hub_kwargs` so it never leaks into the component loaders
        # when the checkpoint is resolved locally.
        hub_kwargs = dict(kwargs.pop("hub_kwargs", None) or {})

        if pretrained_model_name_or_path in (None, "", "."):
            variant = repo_root
        elif (
            isinstance(pretrained_model_name_or_path, str)
            and "/" in pretrained_model_name_or_path
            and not Path(pretrained_model_name_or_path).exists()
        ):
            from huggingface_hub import snapshot_download

            if subfolder:
                # Only fetch the requested variant plus the shared label files.
                hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
            cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
            variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
        else:
            variant = Path(pretrained_model_name_or_path)
            if not variant.is_absolute():
                candidate = (Path.cwd() / variant).resolve()
                variant = candidate if candidate.exists() else (repo_root / variant).resolve()
            if subfolder:
                variant = variant / subfolder

        model_kwargs = dict(kwargs)
        # Component folders temporarily pushed onto `sys.path`; removed in `finally`.
        inserted: List[str] = []

        def _load_component(folder: str, module_name: str, class_name: str):
            """Import `module_name` from `variant/folder` and load `class_name` from it.

            Returns `None` when the module file or its config file is missing.
            """
            comp_dir = variant / folder
            module_path = comp_dir / f"{module_name}.py"
            has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
            if not module_path.exists() or not has_config:
                return None

            comp_path = str(comp_dir)
            if comp_path not in sys.path:
                sys.path.insert(0, comp_path)
                inserted.append(comp_path)

            # NOTE(review): `import_module` caches by bare module name, so a second
            # variant loaded in the same process reuses the first variant's module.
            module = importlib.import_module(module_name)
            component_cls = getattr(module, class_name)
            return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

        try:
            transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
            scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

            if transformer is None:
                raise ValueError(f"No loadable transformer found under {variant}")

            variant_path = str(variant)
            id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

            pipe = cls(
                transformer=transformer,
                scheduler=scheduler,
                id2label=id2label,
                id2label_cn=id2label_cn,
            )
            # Remember where the variant lives so labels can be lazily re-resolved.
            if variant_path and hasattr(pipe, "register_to_config"):
                pipe.register_to_config(_name_or_path=variant_path)
            return pipe
        finally:
            for comp_path in inserted:
                if comp_path in sys.path:
                    sys.path.remove(comp_path)

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[int, str]] = None,
        id2label_cn: Optional[Dict[int, str]] = None,
    ):
        """Register the components and build the label -> class-id lookup tables."""
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler)

        self._id2label = id2label or {}
        self._id2label_cn = id2label_cn or {}
        # Reverse lookups: one entry per comma-separated synonym.
        self.labels = self._build_label2id(self._id2label)
        self.labels_cn = self._build_label2id(self._id2label_cn)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load label files when the pipeline was built without any labels.

        Does nothing once either language mapping is populated. Otherwise the
        location is derived from the `_name_or_path` config entry, if present.
        """
        if self._id2label or self._id2label_cn:
            return
        loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
        if loaded_en:
            self._id2label = loaded_en
            self.labels = self._build_label2id(self._id2label)
        if loaded_cn:
            self._id2label_cn = loaded_cn
            self.labels_cn = self._build_label2id(self._id2label_cn)

    @staticmethod
    def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
        """Return the `labels/` directory that sits next to the variant folder.

        Returns `None` when `variant_path` is empty or the directory does not exist.
        """
        if not variant_path:
            return None
        variant_dir = Path(variant_path).resolve()
        labels_dir = variant_dir.parent / "labels"
        return labels_dir if labels_dir.is_dir() else None

    @staticmethod
    def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
        """Read `id2label_en.json` (`lang == "en"`) or `id2label_cn.json` (any other `lang`).

        Keys are converted to `int`.

        Raises:
            FileNotFoundError: If the selected JSON file does not exist.
        """
        filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
        path = labels_dir / filename
        if not path.exists():
            raise FileNotFoundError(path)
        raw = json.loads(path.read_text(encoding="utf-8"))
        return {int(key): value for key, value in raw.items()}

    @classmethod
    def _load_labels_for_variant(
        cls,
        variant_path: Optional[str],
    ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
        """Load the English and Chinese label mappings for a variant folder.

        Each language is loaded independently, so a missing file for one language
        no longer discards the labels of the other.

        Returns:
            `(id2label_en, id2label_cn)`; an entry is `None` when its file (or the
            whole `labels/` directory) is missing.
        """
        labels_dir = cls._labels_dir_for_variant(variant_path)
        if labels_dir is None:
            return None, None

        loaded: List[Optional[Dict[int, str]]] = []
        for lang in ("en", "cn"):
            try:
                loaded.append(cls._read_id2label(labels_dir, lang))
            except FileNotFoundError:
                loaded.append(None)
        return loaded[0], loaded[1]

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert `id2label` into a synonym -> class-id mapping, sorted by synonym.

        Each value is split on commas; blank synonyms are dropped. When the same
        synonym appears under several ids, the one iterated last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    @property
    def id2label_cn(self) -> Dict[int, str]:
        """ImageNet class id to Chinese label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label_cn

    def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more label strings. Each string must match a synonym in `id2label` (English)
                or `id2label_cn` (Chinese).
            lang (`str`, *optional*, defaults to `"en"`):
                `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.

        Returns:
            `list[int]`: One class id per input label, in input order.

        Raises:
            ValueError: If `lang` is invalid, no labels are loaded for `lang`, or a label is unknown.
        """
        if lang not in ("en", "cn"):
            raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

        self._ensure_labels_loaded()
        label2id = self.labels if lang == "en" else self.labels_cn
        if not label2id:
            raise ValueError(
                f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
            )

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            # Show a handful of valid labels to help the caller correct the input.
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(
                f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
            )
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
    ) -> List[int]:
        """Convert the user-facing `class_labels` argument into a list of class ids.

        Accepts a single int, a single label string, or a list. A list is treated
        as label strings when its first element is a string; strings are resolved
        against the English synonyms first and the Chinese synonyms second. Any
        other list is returned as a plain `list` copy.

        Raises:
            ValueError: If string labels cannot be resolved.
        """
        if isinstance(class_labels, int):
            return [class_labels]

        if isinstance(class_labels, str):
            # A lone string gets the same English-then-Chinese resolution as a
            # list of strings; unknown labels still surface the English error.
            self._ensure_labels_loaded()
            if class_labels not in self.labels and class_labels in self.labels_cn:
                return self.get_label_ids(class_labels, lang="cn")
            return self.get_label_ids(class_labels)

        if class_labels and isinstance(class_labels[0], str):
            self._ensure_labels_loaded()
            if all(label in self.labels for label in class_labels):
                return self.get_label_ids(class_labels, lang="en")
            if all(label in self.labels_cn for label in class_labels):
                return self.get_label_ids(class_labels, lang="cn")
            raise ValueError(
                "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
                "or Chinese synonyms from `pipe.labels_cn`."
            )

        return list(class_labels)

    def _predict_velocity(
        self,
        z_value: torch.Tensor,
        t: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
    ) -> torch.Tensor:
        """Run the transformer at time `t` and return the (optionally guided) velocity.

        With classifier-free guidance the batch is doubled (conditional half first,
        null-class half second) and the two velocities are blended. Guidance is
        applied only while `t` lies inside the open guidance interval; outside it
        the scale falls back to 1.0, i.e. the conditional velocity.
        """
        t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
        if do_classifier_free_guidance:
            z_in = torch.cat([z_value, z_value], dim=0)
            labels = torch.cat([class_labels, class_null], dim=0)
        else:
            z_in = z_value
            labels = class_labels

        # The transformer receives one timestep per batch element.
        t_batch = t.flatten().expand(z_in.shape[0])
        x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
        v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)

        if not do_classifier_free_guidance:
            return v

        v_cond, v_uncond = v.chunk(2, dim=0)
        interval_mask = t < guidance_interval_max
        # A lower bound of exactly 0.0 disables the lower-bound test.
        if guidance_interval_min != 0.0:
            interval_mask = interval_mask & (t > guidance_interval_min)
        scale = torch.where(
            interval_mask,
            torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
            torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
        )
        return v_uncond + scale * (v_cond - v_uncond)

    def _run_sampler(
        self,
        latents: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        num_inference_steps: int,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
        sampling_method: str,
    ) -> torch.Tensor:
        """Integrate the latents along the scheduler's time grid.

        The first `num_inference_steps - 1` intervals go through `scheduler.step`
        (Heun when `sampling_method == "heun"`, otherwise Euler). The last
        interval is always a plain Euler update computed inline.

        NOTE: `set_timesteps(..., solver=sampling_method)` writes the solver into
        the scheduler config, so the choice persists on the scheduler afterwards.
        """
        device = latents.device
        self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
        timesteps = self.scheduler.timesteps

        for i in self.progress_bar(range(num_inference_steps - 1)):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            v = self._predict_velocity(
                latents,
                t,
                class_labels,
                class_null,
                do_classifier_free_guidance,
                guidance_scale,
                guidance_interval_min,
                guidance_interval_max,
            )

            if sampling_method == "heun":
                # Euler predictor; the velocity at the predicted point feeds the
                # scheduler's trapezoidal corrector.
                latents_euler = latents + (t_next - t) * v
                v_next = self._predict_velocity(
                    latents_euler,
                    t_next,
                    class_labels,
                    class_null,
                    do_classifier_free_guidance,
                    guidance_scale,
                    guidance_interval_min,
                    guidance_interval_max,
                )
                latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
            else:
                latents = self.scheduler.step(v, t, latents).prev_sample

        # Final interval: single Euler update, no second velocity evaluation.
        t = timesteps[-2]
        t_next = timesteps[-1]
        v = self._predict_velocity(
            latents,
            t,
            class_labels,
            class_null,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
        )
        return latents + (t_next - t) * v

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        guidance_scale: Optional[float] = None,
        guidance_interval_min: float = 0.1,
        guidance_interval_max: float = 1.0,
        noise_scale: Optional[float] = None,
        t_eps: Optional[float] = None,
        sampling_method: Optional[str] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images.

        Args:
            class_labels (`int`, `str`, `list[int]`, or `list[str]`):
                ImageNet class indices or human-readable label strings (English or Chinese).
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval_min (`float`, defaults to `0.1`):
                Lower bound of the CFG interval in flow time `t in [0, 1]`.
            guidance_interval_max (`float`, defaults to `1.0`):
                Upper bound of the CFG interval in flow time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
            t_eps (`float`, *optional*):
                Epsilon clamp for the `1 - t` denominator (scheduler config by default).
                When given, the value is written into the scheduler config.
            sampling_method (`str`, *optional*):
                `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            num_inference_steps (`int`, defaults to `50`):
                Number of solver steps (at least 2).
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"pt"`. Any other value is treated as `"pil"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Returns:
            [`ImagePipelineOutput`] or a one-element `tuple` holding the images.

        Raises:
            ValueError: If the solver name is invalid, `num_inference_steps < 2`, or
                string labels cannot be resolved.
        """
        solver = sampling_method or self.scheduler.config.solver
        if solver not in {"heun", "euler"}:
            raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")

        if t_eps is not None:
            self.scheduler.register_to_config(t_eps=t_eps)

        class_label_ids = self._normalize_class_labels(class_labels)
        do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        # The index one past the last real class is used as the null class.
        null_class_val = int(self.transformer.config.num_classes)

        if guidance_scale is None:
            guidance_scale = 1.0
        if noise_scale is None:
            noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

        latents = (
            randn_tensor(
                shape=(batch_size, channels, image_size, image_size),
                generator=generator,
                device=self._execution_device,
                dtype=self.transformer.dtype,
            )
            * noise_scale
        )

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # NOTE: out-of-range ids are silently clamped into [0, num_classes - 1].
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
        class_null = torch.full_like(class_labels_t, null_class_val)

        latents = self._run_sampler(
            latents,
            class_labels_t,
            class_null,
            num_inference_steps,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
            solver,
        )

        # Map from [-1, 1] to [0, 1] and move to CPU for output conversion.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
JiT-B-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-B-16/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Output class for the JiT scheduler's `step` function.

    Returned by `JiTScheduler.step` when `return_dict=True`.

    Args:
        prev_sample (`torch.Tensor`):
            Updated sample after one solver step along the JiT flow-time grid.
    """

    # Sample advanced from the current grid time to the next one.
    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        """
        Args:
            num_train_timesteps (`int`, defaults to `1000`):
                Stored in the config; not read by this scheduler's own methods.
            t_eps (`float`, defaults to `5e-2`):
                Lower clamp applied to the `1 - t` denominator in `velocity_from_prediction`.
            solver (`str`, defaults to `"heun"`):
                `"heun"` or `"euler"`.

        Raises:
            ValueError: If `solver` is not `"heun"` or `"euler"`.
        """
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")
        # Populated by `set_timesteps`.
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @property
    def init_noise_sigma(self) -> float:
        """Initial noise multiplier; always `1.0` for this scheduler."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """
        Build the linear time grid and reset the internal step counter.

        Args:
            num_inference_steps (`int`):
                Number of solver intervals (at least 2). The grid holds
                `num_inference_steps + 1` points from `0.0` to `1.0`.
            device (`str` or `torch.device`, *optional*):
                Device on which the grid tensor is created.
            solver (`str`, *optional*):
                When given, replaces the `solver` entry of the scheduler config.

        Raises:
            ValueError: If `num_inference_steps < 2` or `solver` is not `"heun"`/`"euler"`.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        # Validate before touching any state, mirroring the check in `__init__`.
        if solver is not None and solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        # NOTE(review): this grid has `num_inference_steps` entries while `timesteps`
        # has `num_inference_steps + 1` -- confirm whether the mismatch is intended.
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is accepted and ignored."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """
        Find the grid index for the current step.

        An explicit `timestep` that matches a grid point (within `1e-6`) takes
        precedence. Otherwise the internal step counter is used, falling back to
        `0` when the counter is unset.

        Raises:
            ValueError: If `set_timesteps` has not been called.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        if timestep is not None:
            # For tensor input only the first element is inspected.
            t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
            matches = (self.timesteps - t_value).abs() < 1e-6
            if matches.any():
                return int(matches.nonzero(as_tuple=False)[0].item())

        if self._step_index is not None:
            return self._step_index
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. When omitted, or when it matches no grid
                point, uses the internal step index.
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next` (required for Heun intermediate steps). Without
                it the update is a plain Euler step even when the solver is `"heun"`.
            return_dict (`bool`, defaults to `True`):
                Return a [`JiTSchedulerOutput`] instead of a one-element tuple.

        Raises:
            ValueError: If `set_timesteps` has not been called, or the resolved
                step is already the last grid point.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            # Trapezoidal rule: average the velocities at both ends of the interval.
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.

        `timestep` is converted to `sample`'s device and dtype and padded with
        trailing singleton dimensions until it has as many dimensions as `sample`.
        The denominator is clamped from below at `config.t_eps`.
        """
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-B-16/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 128,
6
+ "dropout": 0.0,
7
+ "hidden_size": 768,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 4,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 12,
14
+ "num_classes": 1000,
15
+ "num_layers": 12,
16
+ "patch_size": 16,
17
+ "sample_size": 256
18
+ }
JiT-B-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b68278f2e16a2842bbc17e7d38bc08d22475e1d748bb2e672a9b7e8aff5b4772
3
+ size 525298808
JiT-B-16/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """Concatenate `tensors` along `dim` after broadcasting every other axis.

    All tensors must have the same number of dimensions. On each axis other than
    `dim` the inputs may take at most two distinct sizes; every tensor is expanded
    to the largest size on those axes and keeps its own size on `dim`.

    Raises:
        ValueError: If the tensors differ in rank, or a non-concatenation axis
            has more than two distinct sizes.
    """
    ranks = {t.ndim for t in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    (rank,) = ranks
    if dim < 0:
        dim += rank

    # axis_sizes[a] holds the size of axis `a` for every tensor, in input order.
    axis_sizes = list(zip(*(tuple(t.shape) for t in tensors)))
    for axis, sizes in enumerate(axis_sizes):
        if axis != dim and len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")

    concat_sizes = axis_sizes[dim]
    broadcast_shape = [max(sizes) for sizes in axis_sizes]

    expanded = []
    for tensor, own_size in zip(tensors, concat_sizes):
        target = list(broadcast_shape)
        target[dim] = own_size  # the concatenation axis is never broadcast
        expanded.append(tensor.expand(*target))
    return torch.cat(expanded, dim=dim)
51
+
52
+
53
def rotate_half(x):
    """Map each consecutive pair `(a, b)` on the last axis to `(-b, a)`.

    The last axis is viewed as `x.shape[-1] // 2` pairs of two elements.
    """
    paired = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = paired[..., 0], paired[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pairs, 2) axes back into a single last axis.
    return rotated.view(*rotated.shape[:-2], -1)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """Precomputed 2D rotary position tables for a square token grid.

    Builds cosine/sine tables over an `ft_seq_len x ft_seq_len` grid and stores
    them as non-persistent buffers `freqs_cos` and `freqs_sin`. When
    `num_cls_token > 0`, that many rows with `cos = 1` and `sin = 0` are
    prepended to both tables.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
            freqs = 1.0 / (theta**exponents)
        else:
            freqs = custom_freqs

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Grid positions rescaled from `ft_seq_len` onto the `pt_seq_len` range.
        positions = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum("..., f -> ... f", positions, freqs)
        freqs = freqs.repeat_interleave(2, dim=-1)
        # Pair the per-axis frequencies into one (grid, grid, features) tensor.
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        feature_dim = freqs.shape[-1]
        cos_table = freqs.cos().view(-1, feature_dim)  # [N_img, D]
        sin_table = freqs.sin().view(-1, feature_dim)

        if num_cls_token > 0:
            # prepend in-context cls token rows (cos = 1, sin = 0)
            cos_pad = torch.ones(num_cls_token, feature_dim, dtype=cos_table.dtype)
            sin_pad = torch.zeros(num_cls_token, feature_dim, dtype=sin_table.dtype)
            cos_table = torch.cat([cos_pad, cos_table], dim=0)
            sin_table = torch.cat([sin_pad, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        """Apply the rotary tables to `t`, using the first `t.shape[1]` table rows."""
        # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
        seq_len = t.shape[1]
        cos = self.freqs_cos[:seq_len].to(t.dtype)
        sin = self.freqs_sin[:seq_len].to(t.dtype)

        # Insert a singleton axis at index 1 of each table before multiplying.
        return t * cos[:, None, :] + rotate_half(t) * sin[:, None, :]
107
+
108
+
109
def modulate(x, shift, scale):
    """Apply adaLN modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` gain a new axis at position 1 so they broadcast
    across the token axis of ``x``.
    """
    scale_b = scale[:, None]
    shift_b = shift[:, None]
    return x * (1 + scale_b) + shift_b
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """Image-to-patch embedding with a low-dimensional bottleneck.

    A strided, bias-free convolution projects each patch to ``pca_dim``
    channels; a 1x1 convolution then lifts it to ``embed_dim``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        patches_per_col = self.img_size[0] // self.patch_size[0]
        patches_per_row = self.img_size[1] // self.patch_size[1]
        self.num_patches = patches_per_row * patches_per_col

        self.proj1 = nn.Conv2d(
            in_chans, pca_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
        )
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Return patch tokens of shape (batch, num_patches, embed_dim)."""
        bottleneck = self.proj1(x)
        tokens = self.proj2(bottleneck)
        return tokens.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings ``[cos(t * f), sin(t * f)]`` of width ``dim``.

        An odd ``dim`` is zero-padded with one trailing column.
        """
        half = dim // 2
        scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(scaled / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t, dtype=None):
        """Embed timesteps ``t``, optionally casting the features to ``dtype`` before the MLP."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            features = features.to(dtype=dtype)
        return self.mlp(features)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Lookup table that maps class labels to vectors.

    The table holds ``num_classes + 1`` rows, i.e. one extra index beyond the
    regular class ids.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        """Return the embedding rows selected by ``labels``."""
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head RMS q/k norm and rotary embedding.

    Args:
        dim: Token feature dimension.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_norm: Apply RMSNorm to queries and keys when True, identity otherwise.
        attn_drop: Attention dropout probability (used only in training mode).
        proj_drop: Dropout probability after the output projection.
        eps: Epsilon for the q/k RMSNorm layers.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads

        self.q_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        """Attend over the tokens of ``x`` (batch, tokens, dim); ``rope`` optionally rotates q and k."""
        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        # -> three tensors laid out as (batch, heads, tokens, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # The rotary module expects (batch, tokens, heads, head_dim).
            q = rope(q.transpose(1, 2)).transpose(1, 2)
            k = rope(k.transpose(1, 2)).transpose(1, 2)

        if self.training:
            dropout_p = self.attn_drop
        else:
            dropout_p = 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network with a fused gate/value input projection.

    The requested ``hidden_dim`` is scaled by 2/3 (truncated to int) before
    the layers are built.
    """

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        """Compute ``w3(dropout(silu(a) * b))`` where ``a, b`` are the two halves of ``w12(x)``."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        activated = self.ffn_dropout(activated)
        return self.w3(activated)
235
+
236
+
237
class JiTBlock(nn.Module):
    """Transformer block with adaLN-modulated attention and SwiGLU MLP branches.

    The conditioning vector is passed through SiLU and a linear layer that
    yields shift, scale and gate terms for each of the two residual branches.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        self.act = nn.SiLU()
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        """Update tokens ``x`` given conditioning ``c``; ``feat_rope`` is forwarded to attention."""
        modulation = self.adaLN_modulation(self.act(c))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)

        # Attention branch: normalise, modulate, attend, gated residual add.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in, rope=feat_rope)

        # MLP branch: same pattern with the second set of modulation terms.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build fixed 2D sin-cos position embeddings for a ``grid_size`` x ``grid_size`` grid.

    Returns an array with one row per grid cell and ``embed_dim`` columns.
    When ``cls_token`` is truthy and ``extra_tokens > 0``, that many all-zero
    rows are prepended.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return pos_embed
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, pos_embed], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode ``grid[0]`` and ``grid[1]`` with half of ``embed_dim`` each and concatenate.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    per_axis_dim = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return ``[sin(pos * w), cos(pos * w)]`` embeddings of width ``embed_dim``.

    ``pos`` is flattened; the result has one row per position. The
    frequencies ``w`` are ``1 / 10000 ** (k / (embed_dim / 2))`` for
    ``k = 0 .. embed_dim // 2 - 1``.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents

    flat_pos = pos.reshape(-1)
    angles = np.einsum("m,d->md", flat_pos, omega)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels.
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout in the middle quarter of blocks.
        dropout (`float`, defaults to `0.0`):
            Projection dropout in the middle quarter of blocks.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            PCA bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.

    Raises:
        ValueError: If `sample_size` is not a multiple of `patch_size`, if `hidden_size` is not a multiple of
            `num_attention_heads`, or if the per-head dimension is not a multiple of 4 (required by the 2D
            rotary embedding).
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()

        # Fail early on geometries that would otherwise surface as opaque shape
        # errors (or a mismatched output resolution) inside `forward`.
        if sample_size % patch_size != 0:
            raise ValueError(f"sample_size ({sample_size}) must be divisible by patch_size ({patch_size}).")
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})."
            )
        if (hidden_size // num_attention_heads) % 4 != 0:
            raise ValueError(
                "The per-head dimension (hidden_size // num_attention_heads) must be divisible by 4 "
                f"for the 2D rotary embedding, but got {hidden_size // num_attention_heads}."
            )

        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos)
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one table for the plain patch sequence and one with identity
        # rows prepended for the in-context tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks: dropout is only enabled for blocks in [num_layers // 4, num_layers // 4 * 3).
        drop_start = num_layers // 4
        drop_end = num_layers // 4 * 3
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if (drop_end > i >= drop_start) else 0.0,
                    proj_drop=dropout if (drop_end > i >= drop_start) else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        r"""
        Run the transformer on a batch of images.

        Args:
            hidden_states (`torch.Tensor`):
                Input images, patchified by the bottleneck patch embedder.
            timestep (`torch.Tensor`):
                One timestep per batch element; converted to float by the timestep embedder.
            class_labels (`torch.LongTensor`):
                Class indices looked up in the label embedding table.
            return_dict (`bool`, defaults to `True`):
                Return a [`Transformer2DModelOutput`] if True, otherwise a one-element tuple.

        Returns:
            The output with the same channel count as the input, reassembled from patches.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        # Ensure the label embedding matches the hidden_states dtype.
        y_emb = self.y_embedder(class_labels).to(dtype=hidden_states.dtype)
        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Track whether the in-context tokens were really prepended, so the RoPE
        # table and the final slice stay consistent even when `in_context_start`
        # falls outside the block range.
        in_context_inserted = False
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                in_context_inserted = True

            rope = self.feat_rope_incontext if in_context_inserted else self.feat_rope

            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were actually added).
        if in_context_inserted:
            x = x[:, self.in_context_len :]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p)
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
JiT-B-32/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
  }
JiT-B-32/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Default initial-noise scale keyed by the transformer's `sample_size`.
# `JiTPipeline.__call__` looks the image size up here when `noise_scale` is not
# given and falls back to 1.0 for sizes that are not listed.
RECOMMENDED_NOISE_BY_SIZE = {
    256: 1.0,
    512: 2.0,
}
33
+
34
+
35
class JiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation using JiT (Just image Transformer).

    Parameters:
        transformer ([`JiTTransformer2DModel`]):
            A class-conditioned `JiTTransformer2DModel` to denoise the images.
        scheduler ([`JiTScheduler`]):
            Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
        id2label_cn (`dict[int, str]`, *optional*):
            ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
        """Load a self-contained variant folder locally or from the Hub.

        `None`, `""` and `"."` resolve to the folder containing this file. A string containing `/` that is
        not an existing path is treated as a Hub repo id and fetched with `snapshot_download` (extra
        arguments for it can be passed as `hub_kwargs={...}`). Anything else is a local path, resolved
        against the current working directory first and this file's folder second. All remaining keyword
        arguments are forwarded to each component's `from_pretrained`.

        Examples:
            JiTPipeline.from_pretrained(".")
            JiTPipeline.from_pretrained("./JiT-H-32")
            DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)

        Raises:
            ValueError: If no loadable transformer is found in the resolved variant folder.
        """
        repo_root = Path(__file__).resolve().parent

        # `hub_kwargs` only concerns `snapshot_download`; pop it on every path so
        # it is never forwarded to the component loaders.
        hub_kwargs = dict(kwargs.pop("hub_kwargs", None) or {})

        if pretrained_model_name_or_path in (None, "", "."):
            variant = repo_root
        elif (
            isinstance(pretrained_model_name_or_path, str)
            and "/" in pretrained_model_name_or_path
            and not Path(pretrained_model_name_or_path).exists()
        ):
            from huggingface_hub import snapshot_download

            if subfolder:
                # Restrict the download to the requested variant plus the shared labels.
                hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
            cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
            variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
        else:
            variant = Path(pretrained_model_name_or_path)
            if not variant.is_absolute():
                candidate = (Path.cwd() / variant).resolve()
                variant = candidate if candidate.exists() else (repo_root / variant).resolve()
            if subfolder:
                variant = variant / subfolder

        model_kwargs = dict(kwargs)
        inserted: List[str] = []

        def _load_component(folder: str, module_name: str, class_name: str):
            """Import `class_name` from `<variant>/<folder>/<module_name>.py` and load it, or return None."""
            comp_dir = variant / folder
            module_path = comp_dir / f"{module_name}.py"
            has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
            if not module_path.exists() or not has_config:
                return None

            # Make the component's module importable by its bare name; the path
            # entry is removed again in the `finally` clause below.
            comp_path = str(comp_dir)
            if comp_path not in sys.path:
                sys.path.insert(0, comp_path)
                inserted.append(comp_path)

            module = importlib.import_module(module_name)
            component_cls = getattr(module, class_name)
            return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

        try:
            transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
            scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

            if transformer is None:
                raise ValueError(f"No loadable transformer found under {variant}")

            variant_path = str(variant)
            id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

            pipe = cls(
                transformer=transformer,
                scheduler=scheduler,
                id2label=id2label,
                id2label_cn=id2label_cn,
            )
            if variant_path and hasattr(pipe, "register_to_config"):
                pipe.register_to_config(_name_or_path=variant_path)
            return pipe
        finally:
            for comp_path in inserted:
                if comp_path in sys.path:
                    sys.path.remove(comp_path)

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[int, str]] = None,
        id2label_cn: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler)

        self._id2label = id2label or {}
        self._id2label_cn = id2label_cn or {}
        # Reverse maps: synonym string -> class id.
        self.labels = self._build_label2id(self._id2label)
        self.labels_cn = self._build_label2id(self._id2label_cn)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load label files from the variant path when no labels are present yet."""
        if self._id2label or self._id2label_cn:
            return
        loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
        if loaded_en:
            self._id2label = loaded_en
            self.labels = self._build_label2id(self._id2label)
        if loaded_cn:
            self._id2label_cn = loaded_cn
            self.labels_cn = self._build_label2id(self._id2label_cn)

    @staticmethod
    def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
        """Return the `labels` directory that sits next to the variant folder, if it exists."""
        if not variant_path:
            return None
        labels_dir = Path(variant_path).resolve().parent / "labels"
        return labels_dir if labels_dir.is_dir() else None

    @staticmethod
    def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
        """Read `id2label_en.json` (for `lang == "en"`) or `id2label_cn.json` and convert the keys to `int`.

        Raises:
            FileNotFoundError: If the label file does not exist.
        """
        filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
        path = labels_dir / filename
        if not path.exists():
            raise FileNotFoundError(path)
        raw = json.loads(path.read_text(encoding="utf-8"))
        return {int(key): value for key, value in raw.items()}

    @classmethod
    def _load_labels_for_variant(
        cls,
        variant_path: Optional[str],
    ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
        """Load the English and Chinese label maps for a variant.

        Each language is loaded independently: a missing file yields `None` for that language only, so one
        absent file no longer discards the other language's labels.
        """
        labels_dir = cls._labels_dir_for_variant(variant_path)
        if labels_dir is None:
            return None, None

        def _read_or_none(lang: str) -> Optional[Dict[int, str]]:
            try:
                return cls._read_id2label(labels_dir, lang)
            except FileNotFoundError:
                return None

        return _read_or_none("en"), _read_or_none("cn")

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Split each comma-separated label value into synonyms and map every synonym to its class id.

        The result is sorted by synonym. When a synonym occurs for several ids, the last one seen wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    @property
    def id2label_cn(self) -> Dict[int, str]:
        """ImageNet class id to Chinese label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label_cn

    def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more label strings. Each string must match a synonym in `id2label` (English)
                or `id2label_cn` (Chinese).
            lang (`str`, *optional*, defaults to `"en"`):
                `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.

        Raises:
            ValueError: If `lang` is invalid, no labels are loaded for `lang`, or a label is unknown.
        """
        if lang not in ("en", "cn"):
            raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

        self._ensure_labels_loaded()
        label2id = self.labels if lang == "en" else self.labels_cn
        if not label2id:
            raise ValueError(
                f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
            )

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(
                f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
            )
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
    ) -> List[int]:
        """Convert the user-facing `class_labels` argument into a list of class ids.

        Strings are resolved against the English synonyms first and the Chinese synonyms second; this now
        also applies to a single string, which previously was looked up in English only.
        """
        if isinstance(class_labels, int):
            return [class_labels]

        if isinstance(class_labels, str):
            self._ensure_labels_loaded()
            if class_labels not in self.labels and class_labels in self.labels_cn:
                return self.get_label_ids(class_labels, lang="cn")
            # English lookup; also yields the English error messages for unknown labels.
            return self.get_label_ids(class_labels)

        if class_labels and isinstance(class_labels[0], str):
            self._ensure_labels_loaded()
            if all(label in self.labels for label in class_labels):
                return self.get_label_ids(class_labels, lang="en")
            if all(label in self.labels_cn for label in class_labels):
                return self.get_label_ids(class_labels, lang="cn")
            raise ValueError(
                "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
                "or Chinese synonyms from `pipe.labels_cn`."
            )

        return list(class_labels)

    def _predict_velocity(
        self,
        z_value: torch.Tensor,
        t: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
    ) -> torch.Tensor:
        """Evaluate the transformer at flow time `t` and return the (optionally guided) velocity.

        With classifier-free guidance the batch is doubled (conditional + null labels). The guidance scale
        is applied only while `t < guidance_interval_max` and, when `guidance_interval_min` is non-zero,
        `t > guidance_interval_min`; outside that interval the scale is 1.0.
        """
        t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
        if do_classifier_free_guidance:
            z_in = torch.cat([z_value, z_value], dim=0)
            labels = torch.cat([class_labels, class_null], dim=0)
        else:
            z_in = z_value
            labels = class_labels

        t_batch = t.flatten().expand(z_in.shape[0])
        x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
        v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)

        if not do_classifier_free_guidance:
            return v

        v_cond, v_uncond = v.chunk(2, dim=0)
        interval_mask = t < guidance_interval_max
        if guidance_interval_min != 0.0:
            interval_mask = interval_mask & (t > guidance_interval_min)
        scale = torch.where(
            interval_mask,
            torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
            torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
        )
        return v_uncond + scale * (v_cond - v_uncond)

    def _run_sampler(
        self,
        latents: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        num_inference_steps: int,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
        sampling_method: str,
    ) -> torch.Tensor:
        """Integrate the flow over the scheduler's timestep grid.

        The first `num_inference_steps - 1` steps go through `scheduler.step` (with a second velocity
        evaluation for `"heun"`); the last step from `timesteps[-2]` to `timesteps[-1]` is always a plain
        Euler update.
        """
        device = latents.device
        self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
        timesteps = self.scheduler.timesteps

        def velocity(sample: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
            return self._predict_velocity(
                sample,
                time,
                class_labels,
                class_null,
                do_classifier_free_guidance,
                guidance_scale,
                guidance_interval_min,
                guidance_interval_max,
            )

        for i in self.progress_bar(range(num_inference_steps - 1)):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            v = velocity(latents, t)

            if sampling_method == "heun":
                # Euler predictor, then hand both velocities to the scheduler.
                latents_euler = latents + (t_next - t) * v
                v_next = velocity(latents_euler, t_next)
                latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
            else:
                latents = self.scheduler.step(v, t, latents).prev_sample

        # Final step: plain Euler regardless of the solver.
        t = timesteps[-2]
        t_next = timesteps[-1]
        v = velocity(latents, t)
        return latents + (t_next - t) * v

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        guidance_scale: Optional[float] = None,
        guidance_interval_min: float = 0.1,
        guidance_interval_max: float = 1.0,
        noise_scale: Optional[float] = None,
        t_eps: Optional[float] = None,
        sampling_method: Optional[str] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images.

        Args:
            class_labels (`int`, `str`, `list[int]`, or `list[str]`):
                ImageNet class indices or human-readable label strings (English or Chinese).
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval_min (`float`, defaults to `0.1`):
                Lower bound of the CFG interval in flow time `t in [0, 1]`.
            guidance_interval_max (`float`, defaults to `1.0`):
                Upper bound of the CFG interval in flow time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
            t_eps (`float`, *optional*):
                Epsilon clamp for the `1 - t` denominator (scheduler config by default). When given, the
                value is written into the scheduler config.
            sampling_method (`str`, *optional*):
                `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            num_inference_steps (`int`, defaults to `50`):
                Number of solver steps (at least 2).
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"pt"`. Any other value is treated as `"pil"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Raises:
            ValueError: If no scheduler is loaded, the solver name is invalid, or `num_inference_steps < 2`.
        """
        if self.scheduler is None:
            # `from_pretrained` tolerates a variant without a scheduler folder; fail
            # here with a clear message instead of an AttributeError below.
            raise ValueError("No scheduler is loaded; a `JiTScheduler` is required for sampling.")

        solver = sampling_method or self.scheduler.config.solver
        if solver not in {"heun", "euler"}:
            raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")

        if t_eps is not None:
            self.scheduler.register_to_config(t_eps=t_eps)

        class_label_ids = self._normalize_class_labels(class_labels)
        do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        # The null (unconditional) label is the index right after the last real class.
        null_class_val = int(self.transformer.config.num_classes)

        if guidance_scale is None:
            guidance_scale = 1.0
        if noise_scale is None:
            noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

        latents = (
            randn_tensor(
                shape=(batch_size, channels, image_size, image_size),
                generator=generator,
                device=self._execution_device,
                dtype=self.transformer.dtype,
            )
            * noise_scale
        )

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        # NOTE: out-of-range ids are clamped into [0, num_classes - 1] rather than rejected.
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
        class_null = torch.full_like(class_labels_t, null_class_val)

        latents = self._run_sampler(
            latents,
            class_labels_t,
            class_null,
            num_inference_steps,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
            solver,
        )

        # Map from [-1, 1] to [0, 1] and move to CPU for post-processing.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
JiT-B-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-B-32/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Container returned by `JiTScheduler.step` when `return_dict=True`.

    Args:
        prev_sample (`torch.Tensor`):
            The sample after advancing one solver step along the JiT flow-time grid; feed it
            back in as `sample` for the next step.
    """

    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            Stored in the config; not read by any method of this class.
        t_eps (`float`, defaults to `0.05`):
            Lower bound applied to `1 - t` in `velocity_from_prediction`.
        solver (`str`, defaults to `"heun"`):
            Either `"heun"` or `"euler"`.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        self._validate_solver(solver)
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @staticmethod
    def _validate_solver(solver: str) -> None:
        """Raise `ValueError` unless `solver` names a supported integrator."""
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise; always `1.0` for this scheduler."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """
        Build the flow-time grid and reset the internal step counter.

        Args:
            num_inference_steps (`int`):
                Number of solver steps; must be at least 2. `timesteps` gets
                `num_inference_steps + 1` points from 0 to 1 inclusive.
            device (`str` or `torch.device`, *optional*):
                Device on which the grid is created.
            solver (`str`, *optional*):
                When given, overrides the `solver` stored in the config.

        Raises:
            ValueError: If `num_inference_steps < 2` or `solver` is not a supported name.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        # Validate the override before touching any state: an unknown solver name would
        # otherwise be stored in the config and silently run as Euler inside `step`.
        if solver is not None:
            self._validate_solver(solver)

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        # NOTE: `sigmas` has `num_inference_steps` entries (one fewer than `timesteps`)
        # and is not read by `step`; kept as-is for compatibility.
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; this scheduler applies no input scaling."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """
        Map `timestep` to an index into `self.timesteps`.

        The internal counter wins whenever it is set (which `set_timesteps` always does),
        so `timestep` is only consulted when the counter is `None`. A timestep that does
        not match any grid point within `1e-6` falls back to index 0.
        """
        if self._step_index is not None:
            return self._step_index
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if timestep is None:
            return 0
        t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
        matches = (self.timesteps - t_value).abs() < 1e-6
        if matches.any():
            return int(matches.nonzero(as_tuple=False)[0].item())
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. Only used when the internal step index is unset.
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next`. When the solver is `"heun"` and this is provided, the
                update averages both velocities; otherwise a plain Euler update is used.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `JiTSchedulerOutput` instead of a 1-tuple.

        Raises:
            ValueError: If `set_timesteps` was not called or the grid is exhausted.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        # Append trailing singleton axes so `t` broadcasts against `sample`.
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        # Clamp keeps the division finite as t approaches 1.
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-B-32/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 128,
6
+ "dropout": 0.0,
7
+ "hidden_size": 768,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 4,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 12,
14
+ "num_classes": 1000,
15
+ "num_layers": 12,
16
+ "patch_size": 32,
17
+ "sample_size": 512
18
+ }
JiT-B-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:729654b3302fdae22eb4a4de9d2b24545828c82f2e2c8dcd3f5a01fe7c606ba4
3
+ size 533565560
JiT-B-32/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """
    Concatenate `tensors` along `dim` after expanding every other axis to a common size.

    All tensors must have the same rank. On each axis other than `dim` at most two distinct
    sizes may occur; every tensor is expanded to the largest of them before `torch.cat`.

    Raises:
        ValueError: If the ranks differ or an axis carries more than two distinct sizes.
    """
    ranks = {t.ndim for t in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_by_axis[a] lists the size of every tensor along axis `a`.
    sizes_by_axis = list(zip(*(tuple(t.shape) for t in tensors)))

    for axis, sizes in enumerate(sizes_by_axis):
        if axis != dim and len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")

    expanded = []
    for index, tensor in enumerate(tensors):
        # Keep the tensor's own extent on the concat axis, take the maximum elsewhere.
        target = [sizes[index] if axis == dim else max(sizes) for axis, sizes in enumerate(sizes_by_axis)]
        expanded.append(tensor.expand(*target))
    return torch.cat(expanded, dim=dim)
51
+
52
+
53
def rotate_half(x):
    """
    Rotate consecutive pairs along the last axis: `(a, b)` becomes `(-b, a)`.

    The last dimension is viewed as pairs, so it must have even length.
    """
    pairs = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.view(*rotated.shape[:-2], -1)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """
    Precomputed 2D rotary position embedding over a square grid of positions.

    Cosine and sine tables are stored as non-persistent buffers `freqs_cos` and
    `freqs_sin`, one row per grid position. When `num_cls_token > 0`, that many identity
    rows (cos = 1, sin = 0) are prepended so leading tokens pass through unrotated.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        else:
            base_freqs = custom_freqs

        side = pt_seq_len if ft_seq_len is None else ft_seq_len
        # Positions are rescaled so `side` samples span the `pt_seq_len` range.
        positions = torch.arange(side) / side * pt_seq_len

        angles = torch.einsum("..., f -> ... f", positions, base_freqs)
        angles = angles.repeat_interleave(2, dim=-1)
        # Combine the two grid axes: the feature axis becomes the concatenation of both.
        angles = broadcat((angles[:, None, :], angles[None, :, :]), dim=-1)

        flat_angles = angles.view(-1, angles.shape[-1])  # [N_img, D]
        cos_table = flat_angles.cos()
        sin_table = flat_angles.sin()

        if num_cls_token > 0:
            # Prepend identity rotations for the in-context tokens.
            width = cos_table.shape[-1]
            identity_cos = torch.ones(num_cls_token, width, dtype=cos_table.dtype)
            identity_sin = torch.zeros(num_cls_token, width, dtype=sin_table.dtype)
            cos_table = torch.cat([identity_cos, cos_table], dim=0)
            sin_table = torch.cat([identity_sin, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        """Rotate `t`, laid out as (batch, seq_len, heads, head_dim), by the cached tables."""
        length = t.shape[1]
        cos = self.freqs_cos[:length].to(t.dtype)
        sin = self.freqs_sin[:length].to(t.dtype)

        # Insert a heads axis so the tables broadcast over every head.
        return t * cos[:, None, :] + rotate_half(t) * sin[:, None, :]
107
+
108
+
109
def modulate(x, shift, scale):
    """Apply `x * (1 + scale) + shift`, broadcasting `shift` and `scale` over axis 1 of `x`."""
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """
    Image-to-patch embedding with a low-dimensional bottleneck.

    `proj1` is a bias-free strided convolution mapping each patch to `pca_dim` channels;
    `proj2` is a 1x1 convolution lifting those to `embed_dim`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = cols * rows

        self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False)
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Return patch tokens of shape (batch, num_patches, embed_dim)."""
        bottleneck = self.proj1(x)
        features = self.proj2(bottleneck)
        # Collapse the spatial grid into a token axis, then move channels last.
        return features.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Map scalar timesteps to `hidden_size`-dimensional vectors.

    A sinusoidal encoding of width `frequency_embedding_size` is passed through a
    two-layer MLP with a SiLU activation in between.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a 1-D tensor of timesteps.

        Returns a `(len(t), dim)` float tensor laid out as `[cos | sin]`; when `dim` is odd
        a trailing zero column pads the result to the requested width.
        """
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        parts = [torch.cos(angles), torch.sin(angles)]
        embedding = torch.cat(parts, dim=-1)
        if dim % 2:
            padding = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    def forward(self, t, dtype=None):
        """Embed timesteps `t`; the sinusoidal features are cast to `dtype` before the MLP when given."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            features = features.to(dtype=dtype)
        return self.mlp(features)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Look up a learned embedding for each class label.

    The table holds `num_classes + 1` rows, i.e. one extra index beyond the real classes.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        """Return the embedding rows selected by the integer tensor `labels`."""
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """
    Multi-head self-attention with optional per-head RMS normalisation of queries and keys
    and an optional rotary embedding applied to both.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads

        self.q_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Stored as a plain probability; it is handed to SDPA only while training.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        """Attend over `x` of shape (batch, tokens, channels) and return the same shape."""
        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        # After the permute each of q, k, v is (batch, heads, tokens, head_dim).
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # The rotary module expects (batch, tokens, heads, head_dim).
            q = rope(q.transpose(1, 2)).transpose(1, 2)
            k = rope(k.transpose(1, 2)).transpose(1, 2)

        if self.training:
            drop_prob = self.attn_drop
        else:
            drop_prob = 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_prob)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """
    SwiGLU feed-forward layer.

    The requested `hidden_dim` is scaled by 2/3 (truncated to an integer); `w12` produces
    the gate and value halves in one projection and `w3` maps back to `dim`.
    """

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        """Return `w3(dropout(silu(gate) * value))` for the input `x`."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        activated = self.ffn_dropout(activated)
        return self.w3(activated)
235
+
236
+
237
class JiTBlock(nn.Module):
    """
    Transformer block with adaptive layer-norm modulation.

    The conditioning vector yields six chunks (shift, scale and gate for the attention
    branch and for the MLP branch); each branch is added residually, scaled by its gate.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        self.act = nn.SiLU()
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        """Run the block on tokens `x` with conditioning `c`; `feat_rope` is forwarded to attention."""
        # The conditioning is activated before the modulation projection.
        params = self.adaLN_modulation(self.act(c)).chunk(6, dim=-1)
        shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = params

        # Attention branch
        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn.unsqueeze(1) * self.attn(attn_in, rope=feat_rope)

        # MLP branch
        ffn_in = modulate(self.norm2(x), shift_ffn, scale_ffn)
        x = x + gate_ffn.unsqueeze(1) * self.mlp(ffn_in)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a fixed 2D sine-cosine position table for a `grid_size` x `grid_size` grid.

    Returns an array with one row per grid position. When `cls_token` is true and
    `extra_tokens > 0`, that many all-zero rows are prepended.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode both coordinate planes of `grid` and concatenate them feature-wise.

    Half of `embed_dim` is spent on `grid[0]` and the other half on `grid[1]`.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    per_axis = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(per_axis, grid[index]) for index in (0, 1)]
    return np.concatenate(halves, axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode flattened positions `pos` as `[sin | cos]` features of total width `embed_dim`.

    Frequencies are `1 / 10000 ** (k / (embed_dim / 2))` for `k` in `range(embed_dim // 2)`.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents

    flat_pos = pos.reshape(-1)
    # Outer product: one row per position, one column per frequency.
    angles = flat_pos[:, None] * omega[None, :]

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels.
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout, applied only to blocks with index in
            `[num_layers // 4, num_layers // 4 * 3)`.
        dropout (`float`, defaults to `0.0`):
            Projection dropout, applied to the same range of blocks as `attention_dropout`.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            PCA bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted. If it lies outside
            `[0, num_layers)` no in-context tokens are used.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos)
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one table for image tokens only, one with identity rows for in-context tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks: dropout is enabled only for indices in [num_layers // 4, num_layers // 4 * 3).
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
                    proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        """
        Predict an image-shaped output from a noisy image, timestep and class labels.

        Args:
            hidden_states (`torch.Tensor`):
                Input images of shape `(batch, in_channels, height, width)`.
            timestep (`torch.Tensor`):
                One timestep per sample as a 1-D tensor; it is converted to float inside the
                timestep embedder.
            class_labels (`torch.LongTensor`):
                Integer class ids indexing the label embedding table.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a 1-tuple.

        Returns:
            `Transformer2DModelOutput` or `tuple`: the output tensor of shape
            `(batch, in_channels, h * patch_size, w * patch_size)`.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        y_emb = self.y_embedder(class_labels)

        # Ensure embeddings match hidden_states dtype
        y_emb = y_emb.to(dtype=hidden_states.dtype)

        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Tracks whether the in-context tokens are actually present in `x`, so that the RoPE
        # table and the final slice both agree with the real sequence layout even when
        # `in_context_start` never matches a block index.
        context_inserted = False

        # Blocks
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                context_inserted = True

            rope = self.feat_rope_incontext if context_inserted else self.feat_rope

            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were prepended)
        if context_inserted:
            x = x[:, self.in_context_len :]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: tokens are assumed to form a square h x w grid.
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
JiT-H-16/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
  }
JiT-H-16/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Default noise scale keyed by image resolution in pixels.
# NOTE(review): presumably the fallback used when the caller passes no `noise_scale`;
# confirm against `__call__`.
RECOMMENDED_NOISE_BY_SIZE = dict(
    [
        (256, 1.0),
        (512, 2.0),
    ]
)
33
+
34
+
35
+ class JiTPipeline(DiffusionPipeline):
36
+ r"""
37
+ Pipeline for image generation using JiT (Just image Transformer).
38
+
39
+ Parameters:
40
+ transformer ([`JiTTransformer2DModel`]):
41
+ A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
+ scheduler ([`JiTScheduler`]):
43
+ Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
+ id2label (`dict[int, str]`, *optional*):
45
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
+ id2label_cn (`dict[int, str]`, *optional*):
47
+ ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
+ """
49
+
50
+ model_cpu_offload_seq = "transformer"
51
+
52
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
    """Load a self-contained variant folder locally or from the Hub.

    Resolution rules for `pretrained_model_name_or_path`:
      * `None`, `""` or `"."`: use the directory containing this file.
      * a string containing `/` that is not an existing path: treat it as a Hub repo id
        and download it with `snapshot_download` (optionally limited to `subfolder`
        and `labels`).
      * anything else: a local path, resolved against the current working directory
        first and this file's directory second; `subfolder` is appended if given.

    `hub_kwargs` (a dict, optional) is forwarded to `snapshot_download` only. All other
    keyword arguments are passed to each component's own `from_pretrained`.

    Raises:
        ValueError: If no transformer can be loaded from the resolved folder.

    Examples:
        JiTPipeline.from_pretrained(".")
        JiTPipeline.from_pretrained("./JiT-H-32")
        DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
    """
    repo_root = Path(__file__).resolve().parent

    # Consume the Hub-only options before branching so they never leak into the
    # component loaders on the local-path branches; tolerate an explicit `None`.
    hub_kwargs = dict(kwargs.pop("hub_kwargs", None) or {})

    if pretrained_model_name_or_path in (None, "", "."):
        variant = repo_root
    elif (
        isinstance(pretrained_model_name_or_path, str)
        and "/" in pretrained_model_name_or_path
        and not Path(pretrained_model_name_or_path).exists()
    ):
        from huggingface_hub import snapshot_download

        if subfolder:
            hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
        cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
        variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
    else:
        variant = Path(pretrained_model_name_or_path)
        if not variant.is_absolute():
            candidate = (Path.cwd() / variant).resolve()
            variant = candidate if candidate.exists() else (repo_root / variant).resolve()
        if subfolder:
            variant = variant / subfolder

    model_kwargs = dict(kwargs)
    # sys.path entries added while importing component modules; removed again in `finally`.
    inserted: List[str] = []

    def _load_component(folder: str, module_name: str, class_name: str):
        """Import `<variant>/<folder>/<module_name>.py` and load `class_name` from that folder.

        Returns `None` when the module file or its config file is missing.
        """
        comp_dir = variant / folder
        module_path = comp_dir / f"{module_name}.py"
        has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
        if not module_path.exists() or not has_config:
            return None

        comp_path = str(comp_dir)
        if comp_path not in sys.path:
            sys.path.insert(0, comp_path)
            inserted.append(comp_path)

        # NOTE: `import_module` caches by module name, so a module of the same name that
        # was imported earlier in this process is reused.
        module = importlib.import_module(module_name)
        component_cls = getattr(module, class_name)
        return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

    try:
        transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
        scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

        if transformer is None:
            raise ValueError(f"No loadable transformer found under {variant}")

        variant_path = str(variant)
        id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

        pipe = cls(
            transformer=transformer,
            scheduler=scheduler,
            id2label=id2label,
            id2label_cn=id2label_cn,
        )
        if variant_path and hasattr(pipe, "register_to_config"):
            pipe.register_to_config(_name_or_path=variant_path)
        return pipe
    finally:
        for comp_path in inserted:
            if comp_path in sys.path:
                sys.path.remove(comp_path)
127
+
128
def __init__(
    self,
    transformer,
    scheduler,
    id2label: Optional[Dict[int, str]] = None,
    id2label_cn: Optional[Dict[int, str]] = None,
):
    """Register the pipeline components and build the label lookup tables.

    Args:
        transformer: Model registered as the `transformer` module.
        scheduler: Scheduler registered as the `scheduler` module.
        id2label: Optional class id -> English label string mapping.
        id2label_cn: Optional class id -> Chinese label string mapping.
    """
    super().__init__()
    self.register_modules(transformer=transformer, scheduler=scheduler)

    # A missing or empty mapping is stored as an empty dict.
    english = id2label if id2label else {}
    chinese = id2label_cn if id2label_cn else {}
    self._id2label = english
    self._id2label_cn = chinese

    # Synonym -> class id lookups derived from the two mappings.
    self.labels = self._build_label2id(english)
    self.labels_cn = self._build_label2id(chinese)
142
+
143
+ def _ensure_labels_loaded(self) -> None:
144
+ if self._id2label or self._id2label_cn:
145
+ return
146
+ loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
+ if loaded_en:
148
+ self._id2label = loaded_en
149
+ self.labels = self._build_label2id(self._id2label)
150
+ if loaded_cn:
151
+ self._id2label_cn = loaded_cn
152
+ self.labels_cn = self._build_label2id(self._id2label_cn)
153
+
154
+ @staticmethod
155
+ def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
+ if not variant_path:
157
+ return None
158
+ variant_dir = Path(variant_path).resolve()
159
+ labels_dir = variant_dir.parent / "labels"
160
+ return labels_dir if labels_dir.is_dir() else None
161
+
162
+ @staticmethod
163
+ def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
+ filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
+ path = labels_dir / filename
166
+ if not path.exists():
167
+ raise FileNotFoundError(path)
168
+ raw = json.loads(path.read_text(encoding="utf-8"))
169
+ return {int(key): value for key, value in raw.items()}
170
+
171
+ @classmethod
172
+ def _load_labels_for_variant(
173
+ cls,
174
+ variant_path: Optional[str],
175
+ ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
+ labels_dir = cls._labels_dir_for_variant(variant_path)
177
+ if labels_dir is None:
178
+ return None, None
179
+ try:
180
+ return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
+ except FileNotFoundError:
182
+ return None, None
183
+
184
+ @staticmethod
185
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
186
+ label2id: Dict[str, int] = {}
187
+ for class_id, value in id2label.items():
188
+ for synonym in value.split(","):
189
+ synonym = synonym.strip()
190
+ if synonym:
191
+ label2id[synonym] = int(class_id)
192
+ return dict(sorted(label2id.items()))
193
+
194
@property
def id2label(self) -> Dict[int, str]:
    """Class id -> English label string, loaded lazily on first access.

    Each value is a comma-separated string of synonyms.
    """
    self._ensure_labels_loaded()
    english = self._id2label
    return english
199
+
200
@property
def id2label_cn(self) -> Dict[int, str]:
    """Class id -> Chinese label string, loaded lazily on first access.

    Each value is a comma-separated string of synonyms.
    """
    self._ensure_labels_loaded()
    chinese = self._id2label_cn
    return chinese
205
+
206
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
    r"""
    Resolve one or more label strings to their class ids.

    Args:
        label (`str` or `list[str]`):
            A single label or a list of labels. Every entry must be a synonym present in the
            lookup for the chosen language.
        lang (`str`, *optional*, defaults to `"en"`):
            `"en"` looks labels up in `self.labels`; `"cn"` looks them up in `self.labels_cn`.

    Returns:
        `list[int]`: One class id per requested label, in the same order.

    Raises:
        ValueError: If `lang` is unsupported, no labels are loaded for that language, or any
            requested label is unknown.
    """
    if lang != "en" and lang != "cn":
        raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

    self._ensure_labels_loaded()
    if lang == "en":
        lookup = self.labels
    else:
        lookup = self.labels_cn
    if not lookup:
        raise ValueError(
            f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
        )

    requested = [label] if isinstance(label, str) else label

    unknown = [name for name in requested if name not in lookup]
    if unknown:
        # Show a handful of valid labels to help the caller correct the input.
        examples = ", ".join(list(lookup)[:8])
        raise ValueError(
            f"Unknown label(s) for lang={lang!r}: {unknown}. Example valid labels: {examples}, ..."
        )
    return [lookup[name] for name in requested]
237
+
238
+ def _normalize_class_labels(
239
+ self,
240
+ class_labels: Union[int, str, List[Union[int, str]]],
241
+ ) -> List[int]:
242
+ if isinstance(class_labels, int):
243
+ return [class_labels]
244
+
245
+ if isinstance(class_labels, str):
246
+ return self.get_label_ids(class_labels)
247
+
248
+ if class_labels and isinstance(class_labels[0], str):
249
+ self._ensure_labels_loaded()
250
+ if all(label in self.labels for label in class_labels):
251
+ return self.get_label_ids(class_labels, lang="en")
252
+ if all(label in self.labels_cn for label in class_labels):
253
+ return self.get_label_ids(class_labels, lang="cn")
254
+ raise ValueError(
255
+ "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
+ "or Chinese synonyms from `pipe.labels_cn`."
257
+ )
258
+
259
+ return list(class_labels)
260
+
261
+ def _predict_velocity(
262
+ self,
263
+ z_value: torch.Tensor,
264
+ t: torch.Tensor,
265
+ class_labels: torch.Tensor,
266
+ class_null: torch.Tensor,
267
+ do_classifier_free_guidance: bool,
268
+ guidance_scale: float,
269
+ guidance_interval_min: float,
270
+ guidance_interval_max: float,
271
+ ) -> torch.Tensor:
272
+ t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
+ if do_classifier_free_guidance:
274
+ z_in = torch.cat([z_value, z_value], dim=0)
275
+ labels = torch.cat([class_labels, class_null], dim=0)
276
+ else:
277
+ z_in = z_value
278
+ labels = class_labels
279
+
280
+ t_batch = t.flatten().expand(z_in.shape[0])
281
+ x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
+ v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
+
284
+ if not do_classifier_free_guidance:
285
+ return v
286
+
287
+ v_cond, v_uncond = v.chunk(2, dim=0)
288
+ interval_mask = t < guidance_interval_max
289
+ if guidance_interval_min != 0.0:
290
+ interval_mask = interval_mask & (t > guidance_interval_min)
291
+ scale = torch.where(
292
+ interval_mask,
293
+ torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
+ torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
+ )
296
+ return v_uncond + scale * (v_cond - v_uncond)
297
+
298
+ def _run_sampler(
299
+ self,
300
+ latents: torch.Tensor,
301
+ class_labels: torch.Tensor,
302
+ class_null: torch.Tensor,
303
+ num_inference_steps: int,
304
+ do_classifier_free_guidance: bool,
305
+ guidance_scale: float,
306
+ guidance_interval_min: float,
307
+ guidance_interval_max: float,
308
+ sampling_method: str,
309
+ ) -> torch.Tensor:
310
+ device = latents.device
311
+ self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
312
+ timesteps = self.scheduler.timesteps
313
+
314
+ for i in self.progress_bar(range(num_inference_steps - 1)):
315
+ t = timesteps[i]
316
+ t_next = timesteps[i + 1]
317
+ v = self._predict_velocity(
318
+ latents,
319
+ t,
320
+ class_labels,
321
+ class_null,
322
+ do_classifier_free_guidance,
323
+ guidance_scale,
324
+ guidance_interval_min,
325
+ guidance_interval_max,
326
+ )
327
+
328
+ if sampling_method == "heun":
329
+ latents_euler = latents + (t_next - t) * v
330
+ v_next = self._predict_velocity(
331
+ latents_euler,
332
+ t_next,
333
+ class_labels,
334
+ class_null,
335
+ do_classifier_free_guidance,
336
+ guidance_scale,
337
+ guidance_interval_min,
338
+ guidance_interval_max,
339
+ )
340
+ latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
341
+ else:
342
+ latents = self.scheduler.step(v, t, latents).prev_sample
343
+
344
+ t = timesteps[-2]
345
+ t_next = timesteps[-1]
346
+ v = self._predict_velocity(
347
+ latents,
348
+ t,
349
+ class_labels,
350
+ class_null,
351
+ do_classifier_free_guidance,
352
+ guidance_scale,
353
+ guidance_interval_min,
354
+ guidance_interval_max,
355
+ )
356
+ return latents + (t_next - t) * v
357
+
358
@torch.inference_mode()
def __call__(
    self,
    class_labels: Union[int, str, List[Union[int, str]]],
    guidance_scale: Optional[float] = None,
    guidance_interval_min: float = 0.1,
    guidance_interval_max: float = 1.0,
    noise_scale: Optional[float] = None,
    t_eps: Optional[float] = None,
    sampling_method: Optional[str] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    num_inference_steps: int = 50,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Generate class-conditional images.

    Per-call `t_eps` and `sampling_method` overrides are applied to the scheduler config only
    for the duration of this call; the previous values are restored afterwards, so a later
    call that omits them uses the scheduler's own configuration again.

    Args:
        class_labels (`int`, `str`, `list[int]`, or `list[str]`):
            ImageNet class indices or human-readable label strings (English or Chinese).
            Indices are clamped to `[0, num_classes - 1]` before sampling.
        guidance_scale (`float`, *optional*):
            Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
        guidance_interval_min (`float`, defaults to `0.1`):
            Lower bound of the CFG interval in flow time `t in [0, 1]`.
        guidance_interval_max (`float`, defaults to `1.0`):
            Upper bound of the CFG interval in flow time.
        noise_scale (`float`, *optional*):
            Initial Gaussian noise scale. Defaults to the `RECOMMENDED_NOISE_BY_SIZE` entry for
            the transformer's `sample_size`, or `1.0` when there is no entry.
        t_eps (`float`, *optional*):
            Epsilon clamp for the `1 - t` denominator (scheduler config by default).
        sampling_method (`str`, *optional*):
            `"heun"` or `"euler"`. Defaults to the scheduler config.
        generator (`torch.Generator`, *optional*):
            RNG for reproducibility.
        num_inference_steps (`int`, defaults to `50`):
            Number of solver steps (at least 2).
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pt"` returns a tensor, `"np"` a NumPy array; any other value returns PIL images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Return [`ImagePipelineOutput`] if True, otherwise a one-element tuple.

    Raises:
        ValueError: If the solver is not `"heun"`/`"euler"` or `num_inference_steps < 2`.
    """
    solver = sampling_method or self.scheduler.config.solver
    if solver not in {"heun", "euler"}:
        raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
    if num_inference_steps < 2:
        raise ValueError("num_inference_steps must be >= 2.")

    class_label_ids = self._normalize_class_labels(class_labels)
    do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

    batch_size = len(class_label_ids)
    image_size = int(self.transformer.config.sample_size)
    channels = int(self.transformer.config.in_channels)
    # The index one past the last real class is used as the null (unconditional) label.
    null_class_val = int(self.transformer.config.num_classes)

    if guidance_scale is None:
        guidance_scale = 1.0
    if noise_scale is None:
        noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

    latents = (
        randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        )
        * noise_scale
    )

    class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
    class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
    class_null = torch.full_like(class_labels_t, null_class_val)

    # Snapshot the scheduler settings that this call may override (t_eps here, solver inside
    # `set_timesteps`) so the overrides do not leak into subsequent calls.
    saved_t_eps = self.scheduler.config.t_eps
    saved_solver = self.scheduler.config.solver
    try:
        if t_eps is not None:
            self.scheduler.register_to_config(t_eps=t_eps)

        latents = self._run_sampler(
            latents,
            class_labels_t,
            class_null,
            num_inference_steps,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
            solver,
        )
    finally:
        self.scheduler.register_to_config(t_eps=saved_t_eps, solver=saved_solver)

    # Map the clamped [-1, 1] samples to [0, 1] on the CPU.
    images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
    if output_type == "pt":
        images = images_pt
    elif output_type == "np":
        images = images_pt.permute(0, 2, 3, 1).numpy()
    else:
        images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

    self.maybe_free_model_hooks()

    if not return_dict:
        return (images,)
    return ImagePipelineOutput(images=images)
JiT-H-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-H-16/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Result returned by `JiTScheduler.step`.

    Args:
        prev_sample (`torch.Tensor`):
            The sample after one solver step along the JiT flow-time grid.
    """

    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        """Validate the solver name and initialise the (unset) sampling state.

        Raises:
            ValueError: If `solver` is not `"heun"` or `"euler"`.
        """
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @property
    def init_noise_sigma(self) -> float:
        """Return the constant `1.0`."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """Build the time grid for a sampling run and reset the step counter.

        `timesteps` becomes `num_inference_steps + 1` float32 values evenly spaced over
        `[0, 1]`. `sigmas` becomes a Python list of `num_inference_steps` values,
        `1 - linspace(0, 1, num_inference_steps)` (one entry fewer than `timesteps`).

        Args:
            num_inference_steps: Number of solver steps; must be at least 2.
            device: Device on which the grid tensors are created.
            solver: Optional solver override, written to the scheduler config.

        Raises:
            ValueError: If `num_inference_steps < 2`, or if `solver` is given but is not
                `"heun"` or `"euler"`.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        # Apply the same check as `__init__`; an unknown name would otherwise be stored in
        # the config and make `step` silently fall back to Euler updates.
        if solver is not None and solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is ignored."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """Return the grid index for the next step.

        The internal step index takes precedence whenever it is set (which `set_timesteps`
        always does). Only when it is `None` is `timestep` matched against the grid with an
        absolute tolerance of `1e-6`; index `0` is returned when `timestep` is `None` or
        matches no grid point.

        Raises:
            ValueError: If neither the step index nor the time grid has been set.
        """
        if self._step_index is not None:
            return self._step_index
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if timestep is None:
            return 0
        t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
        matches = (self.timesteps - t_value).abs() < 1e-6
        if matches.any():
            return int(matches.nonzero(as_tuple=False)[0].item())
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid and advance the internal step index.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. It is only consulted when the internal step index is
                unset; otherwise the internal index determines the grid position.
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next`. When the configured solver is `"heun"` and this is
                given, the two velocities are averaged; otherwise an Euler update is used.
            return_dict (`bool`, defaults to `True`):
                Return a [`JiTSchedulerOutput`] if True, otherwise a one-element tuple.

        Raises:
            ValueError: If `set_timesteps` has not been called or the final grid point has
                already been reached.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.

        `timestep` is converted to a tensor on `sample`'s device and dtype and given trailing
        singleton dimensions until it has as many dimensions as `sample`. The denominator
        `1 - t` is clamped from below at the configured `t_eps`.
        """
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-H-16/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 256,
6
+ "dropout": 0.2,
7
+ "hidden_size": 1280,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 10,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 16,
14
+ "num_classes": 1000,
15
+ "num_layers": 32,
16
+ "patch_size": 16,
17
+ "sample_size": 256
18
+ }
JiT-H-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6ad4cf51f5ff385db58573a23353b50df4be7a63dd50bdc7b57af404e7b68e7
3
+ size 3811413928
JiT-H-16/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """Concatenate tensors along `dim` after expanding every other dimension to a common size.

    All tensors must have the same number of dimensions, and each
    non-concatenation dimension may take at most two distinct sizes across
    the inputs; it is expanded to the largest of them.

    Raises:
        ValueError: If the ranks differ or a dimension has more than two distinct sizes.
    """
    ranks = {len(tensor.shape) for tensor in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_per_axis[a] holds the size of axis `a` for every input tensor.
    sizes_per_axis = list(zip(*(list(tensor.shape) for tensor in tensors)))
    for axis, sizes in enumerate(sizes_per_axis):
        if axis != dim and len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")

    expanded = []
    for index, tensor in enumerate(tensors):
        target_shape = [
            sizes[index] if axis == dim else max(sizes)
            for axis, sizes in enumerate(sizes_per_axis)
        ]
        expanded.append(tensor.expand(*target_shape))
    return torch.cat(expanded, dim=dim)
51
+
52
+
53
def rotate_half(x):
    """Map each adjacent pair `(a, b)` of the last dimension to `(-b, a)`."""
    pairs = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pairs, 2) axes back into a single feature dimension.
    return rotated.flatten(-2)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """Rotary position embedding over a 2D grid, with optional leading identity rows.

    Cosine and sine tables are precomputed at construction and stored as the
    non-persistent buffers `freqs_cos` and `freqs_sin`.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
            base_freqs = 1.0 / (theta**exponents)
        else:
            base_freqs = custom_freqs

        target_len = pt_seq_len if ft_seq_len is None else ft_seq_len
        positions = torch.arange(target_len) / target_len * pt_seq_len

        # Per-axis angles, with each frequency repeated for the two entries of a rotated pair.
        angles = torch.einsum("..., f -> ... f", positions, base_freqs)
        angles = angles.repeat_interleave(2, dim=-1)
        # Combine the two grid axes: one half of the features per axis.
        grid = broadcat((angles[:, None, :], angles[None, :, :]), dim=-1)

        width = grid.shape[-1]
        cos_table = grid.cos().view(-1, width)
        sin_table = grid.sin().view(-1, width)

        if num_cls_token > 0:
            # Leading rows with cos=1 and sin=0 leave the prepended tokens unrotated.
            cos_pad = torch.ones(num_cls_token, width, dtype=cos_table.dtype)
            sin_pad = torch.zeros(num_cls_token, width, dtype=sin_table.dtype)
            cos_table = torch.cat([cos_pad, cos_table], dim=0)
            sin_table = torch.cat([sin_pad, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        """Rotate `t`, given as (batch, seq_len, heads, head_dim), using the first `seq_len` table rows."""
        length = t.shape[1]
        cos = self.freqs_cos[:length].to(t.dtype)[:, None, :]
        sin = self.freqs_sin[:length].to(t.dtype)[:, None, :]
        return t * cos + rotate_half(t) * sin
107
+
108
+
109
def modulate(x, shift, scale):
    """Return `x * (1 + scale) + shift`, with `shift` and `scale` broadcast over dimension 1."""
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """Image-to-patch embedding that projects through a bottleneck.

    `proj1` is a bias-free strided convolution mapping each patch to `pca_dim`
    channels; `proj2` is a 1x1 convolution mapping those to `embed_dim`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        patches_w = self.img_size[1] // self.patch_size[1]
        patches_h = self.img_size[0] // self.patch_size[0]
        self.num_patches = patches_w * patches_h

        self.proj1 = nn.Conv2d(
            in_chans,
            pca_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Return patch tokens with the spatial grid flattened into dimension 1."""
        bottleneck = self.proj1(x)
        embedded = self.proj2(bottleneck)
        return embedded.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings of width `dim` for the 1-D tensor `t`.

        The first half of the features are cosines and the second half sines; a zero
        column is appended when `dim` is odd.
        """
        half = dim // 2
        scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(scaled / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        features = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            features.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(features, dim=-1)

    def forward(self, t, dtype=None):
        """Embed `t`; the sinusoidal features are cast to `dtype` first when it is given."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            features = features.to(dtype=dtype)
        return self.mlp(features)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Looks up a learned vector for each class label.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.num_classes = num_classes
        # The table has one more row than there are classes.
        self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)

    def forward(self, labels):
        """Return the embedding rows selected by `labels`."""
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head RMS normalisation of queries and keys."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=eps)
            self.k_norm = RMSNorm(head_dim, eps=eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Stored as a plain probability and passed to the fused attention call.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        """Attend over `x` of shape (batch, tokens, channels); `rope`, if given, is applied to q and k."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # Each of q, k, v ends up as (batch, heads, tokens, head_dim).
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # `rope` is called with (batch, tokens, heads, head_dim) inputs.
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            q = rope(q).transpose(1, 2)
            k = rope(k).transpose(1, 2)

        # Attention dropout is only active in training mode.
        dropout_p = self.attn_drop if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: `w3(dropout(silu(x1) * x2))` with `x1, x2` from a fused `w12`."""

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        # The requested hidden width is scaled by 2/3 and truncated to an int.
        inner_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        """Apply the gated feed-forward transform to the last dimension of `x`."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        return self.w3(self.ffn_dropout(activated))
235
+
236
+
237
class JiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated by a conditioning vector.

    The conditioning vector is passed through SiLU and a linear layer to produce
    a shift, scale and gate for each of the two branches.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        self.act = nn.SiLU()
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        """Update tokens `x` given conditioning `c`; `feat_rope` is forwarded to the attention layer."""
        modulation = self.adaLN_modulation(self.act(c))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)

        # Gated residual attention branch.
        attn_input = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_input, rope=feat_rope)

        # Gated residual MLP branch.
        mlp_input = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_input)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a 2D sine/cosine position table for a square `grid_size` x `grid_size` grid.

    When `cls_token` is true and `extra_tokens > 0`, that many all-zero rows
    are prepended to the table.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, table], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Embed `grid[0]` and `grid[1]` with `embed_dim // 2` features each and join them column-wise.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    half = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(half, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return sine/cosine embeddings of the flattened positions `pos`.

    The output has one row per position: `embed_dim // 2` sine columns followed
    by `embed_dim // 2` cosine columns, using frequencies `1 / 10000**(i / (embed_dim / 2))`.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    angles = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels. The output has the same number of channels.
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout, applied only to blocks whose index lies in
            `[num_layers // 4, 3 * (num_layers // 4))`; all other blocks use `0.0`.
        dropout (`float`, defaults to `0.0`):
            Projection dropout, applied to the same range of blocks as `attention_dropout`.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            Bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos), stored as a persistent buffer so it is
        # saved with the weights.
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding (learned positional offsets for the class tokens)
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one table for the plain patch sequence, one that accounts for the
        # prepended in-context tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks: dropout is only enabled for indices in [num_layers // 4, 3 * (num_layers // 4)).
        drop_begin = num_layers // 4
        drop_end = num_layers // 4 * 3
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if (drop_end > i >= drop_begin) else 0.0,
                    proj_drop=dropout if (drop_end > i >= drop_begin) else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        r"""
        Run the transformer on a batch of images and return a same-resolution prediction.

        Args:
            hidden_states (`torch.Tensor`):
                Input images handed to the patch embedder; expected to be
                `(batch, in_channels, sample_size, sample_size)`.
            timestep (`torch.Tensor`):
                Per-sample time values fed to the timestep embedder.
            class_labels (`torch.LongTensor`):
                Per-sample class indices fed to the label embedder.
            return_dict (`bool`, defaults to `True`):
                Return a [`Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`Transformer2DModelOutput`] whose `sample` has shape
            `(batch, out_channels, h * patch_size, w * patch_size)`, or a one-element
            tuple holding that tensor when `return_dict` is false.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        # Ensure the label embedding matches the hidden_states dtype.
        y_emb = self.y_embedder(class_labels).to(dtype=hidden_states.dtype)

        # Conditioning vector shared by every block and by the final modulation.
        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Number of in-context tokens actually prepended; stays 0 when the insertion
        # index is never reached, so real patch tokens are not sliced off below.
        num_prefix_tokens = 0

        # Blocks
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                # `expand` avoids materialising copies; the addition yields a fresh tensor.
                in_context_tokens = y_emb.unsqueeze(1).expand(-1, self.in_context_len, -1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                num_prefix_tokens = self.in_context_len

            rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext

            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were inserted).
        if num_prefix_tokens > 0:
            x = x[:, num_prefix_tokens:]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p), assuming a square token grid.
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
JiT-H-32/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
- }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
+ }
JiT-H-32/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Default initial-noise scale, keyed by the transformer's `sample_size` in pixels.
# Sizes not listed here fall back to 1.0 at the call site.
RECOMMENDED_NOISE_BY_SIZE = {256: 1.0, 512: 2.0}
33
+
34
+
35
class JiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation using JiT (Just image Transformer).

    Parameters:
        transformer ([`JiTTransformer2DModel`]):
            A class-conditioned `JiTTransformer2DModel` to denoise the images.
        scheduler ([`JiTScheduler`]):
            Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
        id2label_cn (`dict[int, str]`, *optional*):
            ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.

    The synonym-to-id lookups built from these mappings are exposed as `labels` (English)
    and `labels_cn` (Chinese).
    """

    model_cpu_offload_seq = "transformer"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
        """Load a self-contained variant folder locally or from the Hub.

        The variant folder is resolved as follows: `None`, `""` or `"."` means the folder
        containing this file; a string containing `/` that is not an existing path is
        downloaded with `snapshot_download` (extra download arguments can be supplied via
        the `hub_kwargs` keyword); anything else is treated as a local path, tried relative
        to the current directory first and then relative to this file's folder.

        Remaining keyword arguments are forwarded to each component's `from_pretrained`.

        Examples:
            JiTPipeline.from_pretrained(".")
            JiTPipeline.from_pretrained("./JiT-H-32")
            DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)

        Raises:
            ValueError: If no loadable transformer is found in the variant folder.
        """
        repo_root = Path(__file__).resolve().parent

        if pretrained_model_name_or_path in (None, "", "."):
            variant = repo_root
        elif (
            isinstance(pretrained_model_name_or_path, str)
            and "/" in pretrained_model_name_or_path
            and not Path(pretrained_model_name_or_path).exists()
        ):
            from huggingface_hub import snapshot_download

            hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
            if subfolder:
                # Only fetch the requested variant plus the shared label files.
                hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
            cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
            variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
        else:
            variant = Path(pretrained_model_name_or_path)
            if not variant.is_absolute():
                candidate = (Path.cwd() / variant).resolve()
                variant = candidate if candidate.exists() else (repo_root / variant).resolve()
            if subfolder:
                variant = variant / subfolder

        model_kwargs = dict(kwargs)
        # Component folders temporarily added to `sys.path`; removed again in `finally`.
        inserted: List[str] = []

        def _load_component(folder: str, module_name: str, class_name: str):
            """Import `<variant>/<folder>/<module_name>.py` and load `class_name` from that folder."""
            comp_dir = variant / folder
            module_path = comp_dir / f"{module_name}.py"
            has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
            if not module_path.exists() or not has_config:
                return None

            comp_path = str(comp_dir)
            if comp_path not in sys.path:
                sys.path.insert(0, comp_path)
                inserted.append(comp_path)

            module = importlib.import_module(module_name)
            component_cls = getattr(module, class_name)
            return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

        try:
            transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
            scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

            if transformer is None:
                raise ValueError(f"No loadable transformer found under {variant}")

            variant_path = str(variant)
            id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

            pipe = cls(
                transformer=transformer,
                scheduler=scheduler,
                id2label=id2label,
                id2label_cn=id2label_cn,
            )
            if variant_path and hasattr(pipe, "register_to_config"):
                pipe.register_to_config(_name_or_path=variant_path)
            return pipe
        finally:
            for comp_path in inserted:
                if comp_path in sys.path:
                    sys.path.remove(comp_path)

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[int, str]] = None,
        id2label_cn: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler)

        self._id2label = id2label or {}
        self._id2label_cn = id2label_cn or {}
        self.labels = self._build_label2id(self._id2label)
        self.labels_cn = self._build_label2id(self._id2label_cn)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load label files from the recorded variant path if none are loaded yet."""
        if self._id2label or self._id2label_cn:
            return
        loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
        if loaded_en:
            self._id2label = loaded_en
            self.labels = self._build_label2id(self._id2label)
        if loaded_cn:
            self._id2label_cn = loaded_cn
            self.labels_cn = self._build_label2id(self._id2label_cn)

    @staticmethod
    def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
        """Return the `labels` directory that sits next to the variant folder, if it exists."""
        if not variant_path:
            return None
        variant_dir = Path(variant_path).resolve()
        labels_dir = variant_dir.parent / "labels"
        return labels_dir if labels_dir.is_dir() else None

    @staticmethod
    def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
        """Read `id2label_en.json` (for `lang == "en"`) or `id2label_cn.json` with integer keys.

        Raises:
            FileNotFoundError: If the label file does not exist.
        """
        filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
        path = labels_dir / filename
        if not path.exists():
            raise FileNotFoundError(path)
        raw = json.loads(path.read_text(encoding="utf-8"))
        return {int(key): value for key, value in raw.items()}

    @classmethod
    def _load_labels_for_variant(
        cls,
        variant_path: Optional[str],
    ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
        """Load the English and Chinese label maps for a variant.

        Each language is loaded independently, so a missing file for one language yields
        `None` for that language only instead of discarding the other one as well.
        """
        labels_dir = cls._labels_dir_for_variant(variant_path)
        if labels_dir is None:
            return None, None

        loaded: List[Optional[Dict[int, str]]] = []
        for lang in ("en", "cn"):
            try:
                loaded.append(cls._read_id2label(labels_dir, lang))
            except FileNotFoundError:
                loaded.append(None)
        return loaded[0], loaded[1]

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert an id-to-label map into a synonym-to-id map, sorted by synonym.

        Each value is split on commas; blank synonyms are ignored. When a synonym occurs
        under several ids, the id encountered last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    @property
    def id2label_cn(self) -> Dict[int, str]:
        """ImageNet class id to Chinese label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label_cn

    def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more label strings. Each string must match a synonym in `id2label` (English)
                or `id2label_cn` (Chinese).
            lang (`str`, *optional*, defaults to `"en"`):
                `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.

        Raises:
            ValueError: If `lang` is unsupported, no labels are loaded for `lang`, or any
                label is unknown.
        """
        if lang not in ("en", "cn"):
            raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

        self._ensure_labels_loaded()
        label2id = self.labels if lang == "en" else self.labels_cn
        if not label2id:
            raise ValueError(
                f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
            )

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(
                f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
            )
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
    ) -> List[int]:
        """Convert user-supplied class labels into a list of integer class ids.

        Accepts a single id, a single label string, a list of ids, or a list of label
        strings. Strings are resolved against the English synonyms first and the Chinese
        synonyms second.

        Raises:
            ValueError: If string labels cannot be resolved.
        """
        if isinstance(class_labels, int):
            return [class_labels]

        if isinstance(class_labels, str):
            self._ensure_labels_loaded()
            # A lone string gets the same English-then-Chinese lookup as a list of strings;
            # anything unresolved falls through to the English path for its detailed error.
            if class_labels not in self.labels and class_labels in self.labels_cn:
                return self.get_label_ids(class_labels, lang="cn")
            return self.get_label_ids(class_labels)

        if class_labels and isinstance(class_labels[0], str):
            self._ensure_labels_loaded()
            if all(label in self.labels for label in class_labels):
                return self.get_label_ids(class_labels, lang="en")
            if all(label in self.labels_cn for label in class_labels):
                return self.get_label_ids(class_labels, lang="cn")
            raise ValueError(
                "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
                "or Chinese synonyms from `pipe.labels_cn`."
            )

        return list(class_labels)

    def _predict_velocity(
        self,
        z_value: torch.Tensor,
        t: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
    ) -> torch.Tensor:
        """Evaluate the transformer at time `t` and return the (optionally guided) velocity.

        With classifier-free guidance the batch is doubled (conditional labels followed by
        null labels) and the two velocities are blended as
        `v_uncond + scale * (v_cond - v_uncond)`, where `scale` is `guidance_scale` while
        `t` lies inside the guidance interval and `1.0` otherwise.
        """
        t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
        if do_classifier_free_guidance:
            z_in = torch.cat([z_value, z_value], dim=0)
            labels = torch.cat([class_labels, class_null], dim=0)
        else:
            z_in = z_value
            labels = class_labels

        t_batch = t.flatten().expand(z_in.shape[0])
        x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
        v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)

        if not do_classifier_free_guidance:
            return v

        v_cond, v_uncond = v.chunk(2, dim=0)
        interval_mask = t < guidance_interval_max
        # A lower bound of exactly 0.0 disables the lower check, so t == 0 is guided too.
        if guidance_interval_min != 0.0:
            interval_mask = interval_mask & (t > guidance_interval_min)
        scale = torch.where(
            interval_mask,
            torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
            torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
        )
        return v_uncond + scale * (v_cond - v_uncond)

    def _run_sampler(
        self,
        latents: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        num_inference_steps: int,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
        sampling_method: str,
    ) -> torch.Tensor:
        """Integrate `latents` along the scheduler's time grid and return the final sample.

        The first `num_inference_steps - 1` intervals go through `scheduler.step` (with a
        second velocity evaluation per step when `sampling_method == "heun"`); the last
        interval is always a plain Euler update computed here.
        """
        device = latents.device
        self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
        timesteps = self.scheduler.timesteps

        for i in self.progress_bar(range(num_inference_steps - 1)):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            v = self._predict_velocity(
                latents,
                t,
                class_labels,
                class_null,
                do_classifier_free_guidance,
                guidance_scale,
                guidance_interval_min,
                guidance_interval_max,
            )

            if sampling_method == "heun":
                # Predictor: Euler step to t_next, then evaluate the velocity there.
                latents_euler = latents + (t_next - t) * v
                v_next = self._predict_velocity(
                    latents_euler,
                    t_next,
                    class_labels,
                    class_null,
                    do_classifier_free_guidance,
                    guidance_scale,
                    guidance_interval_min,
                    guidance_interval_max,
                )
                latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
            else:
                latents = self.scheduler.step(v, t, latents).prev_sample

        # Final interval: single Euler update without a corrector evaluation.
        t = timesteps[-2]
        t_next = timesteps[-1]
        v = self._predict_velocity(
            latents,
            t,
            class_labels,
            class_null,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
        )
        return latents + (t_next - t) * v

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        guidance_scale: Optional[float] = None,
        guidance_interval_min: float = 0.1,
        guidance_interval_max: float = 1.0,
        noise_scale: Optional[float] = None,
        t_eps: Optional[float] = None,
        sampling_method: Optional[str] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images.

        Args:
            class_labels (`int`, `str`, `list[int]`, or `list[str]`):
                ImageNet class indices or human-readable label strings (English or Chinese).
                Indices are clamped to `[0, num_classes - 1]`.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval_min (`float`, defaults to `0.1`):
                Lower bound of the CFG interval in flow time `t in [0, 1]`.
            guidance_interval_max (`float`, defaults to `1.0`):
                Upper bound of the CFG interval in flow time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
            t_eps (`float`, *optional*):
                Epsilon clamp for the `1 - t` denominator (scheduler config by default).
                Applies to this call only.
            sampling_method (`str`, *optional*):
                `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
                Applies to this call only.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            num_inference_steps (`int`, defaults to `50`):
                Number of solver steps (at least 2).
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Raises:
            ValueError: If the solver name is unsupported, `num_inference_steps < 2`, or
                string labels cannot be resolved.
        """
        solver = sampling_method or self.scheduler.config.solver
        if solver not in {"heun", "euler"}:
            raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")

        class_label_ids = self._normalize_class_labels(class_labels)
        do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        # The index one past the last real class is used as the null label for CFG.
        null_class_val = int(self.transformer.config.num_classes)

        if guidance_scale is None:
            guidance_scale = 1.0
        if noise_scale is None:
            noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

        latents = (
            randn_tensor(
                shape=(batch_size, channels, image_size, image_size),
                generator=generator,
                device=self._execution_device,
                dtype=self.transformer.dtype,
            )
            * noise_scale
        )

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
        class_null = torch.full_like(class_labels_t, null_class_val)

        # `t_eps` (below) and the solver (inside `set_timesteps`) are written into the
        # scheduler config. Remember the configured values and restore them afterwards so
        # that a per-call override does not silently become the default for later calls.
        scheduler_config = dict(self.scheduler.config)
        saved_overrides = {key: scheduler_config[key] for key in ("t_eps", "solver") if key in scheduler_config}
        try:
            if t_eps is not None:
                self.scheduler.register_to_config(t_eps=t_eps)

            latents = self._run_sampler(
                latents,
                class_labels_t,
                class_null,
                num_inference_steps,
                do_classifier_free_guidance,
                guidance_scale,
                guidance_interval_min,
                guidance_interval_max,
                solver,
            )
        finally:
            if saved_overrides:
                self.scheduler.register_to_config(**saved_overrides)

        # Map from [-1, 1] to [0, 1] in float32 on the CPU before converting the output.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
JiT-H-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-H-32/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Return container for `JiTScheduler.step`.

    Args:
        prev_sample (`torch.Tensor`):
            The sample after advancing one solver step along the JiT flow-time grid.
    """

    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            Stored in the config; not read by this class.
        t_eps (`float`, defaults to `0.05`):
            Lower clamp applied to the `1 - t` denominator in `velocity_from_prediction`.
        solver (`str`, defaults to `"heun"`):
            `"heun"` or `"euler"`.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        self._check_solver(solver)
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @staticmethod
    def _check_solver(solver: str) -> None:
        """Raise `ValueError` unless `solver` is one of the supported solver names."""
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise expected by this scheduler (always `1.0`)."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """
        Build the time grid for a sampling run and reset the internal step counter.

        Args:
            num_inference_steps (`int`):
                Number of solver steps (at least 2). The grid has `num_inference_steps + 1`
                evenly spaced points from `0.0` to `1.0`.
            device (`str` or `torch.device`, *optional*):
                Device on which the grid tensor is created.
            solver (`str`, *optional*):
                When given, validated and written into the scheduler config.

        Raises:
            ValueError: If `num_inference_steps < 2` or `solver` is not supported.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        # Validate before touching any state so a bad name cannot be stored in the config
        # (an unknown solver would otherwise silently take the Euler branch in `step`).
        if solver is not None:
            self._check_solver(solver)

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        # NOTE(review): `sigmas` has `num_inference_steps` entries while `timesteps` has
        # `num_inference_steps + 1`, and it is not read anywhere in this class — confirm
        # whether the two grids are meant to align.
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; no input scaling is applied."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """
        Determine which grid index the next `step` call should advance from.

        The internal counter takes precedence when set. Otherwise the index is found by
        matching `timestep` against the grid (tolerance `1e-6`), falling back to `0` when
        `timestep` is `None` or matches no grid point.
        """
        if self._step_index is not None:
            return self._step_index
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if timestep is None:
            return 0
        t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
        matches = (self.timesteps - t_value).abs() < 1e-6
        if matches.any():
            return int(matches.nonzero(as_tuple=False)[0].item())
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. When omitted, uses the internal step index.
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next` (required for Heun intermediate steps). Without it the
                step is a plain Euler update even when the solver is `"heun"`.
            return_dict (`bool`, defaults to `True`):
                Return a [`JiTSchedulerOutput`] instead of a plain tuple.

        Raises:
            ValueError: If `set_timesteps` has not been called or the grid is exhausted.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            # Heun: average the velocities at both ends of the interval.
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        # Append trailing singleton dims so `t` broadcasts against `sample`.
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-H-32/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 256,
6
+ "dropout": 0.2,
7
+ "hidden_size": 1280,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 10,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 16,
14
+ "num_classes": 1000,
15
+ "num_layers": 32,
16
+ "patch_size": 32,
17
+ "sample_size": 512
18
+ }
JiT-H-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:578fc2f9f4ccaa34c3d2f5076811e101419e5dfd1b20dcca89bbfb29f5f60ab6
3
+ size 3825578920
JiT-H-32/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """Concatenate tensors along `dim` after broadcasting every other axis.

    All tensors must have the same number of dimensions. On each axis other than
    `dim`, at most two distinct sizes may occur; tensors are expanded to the largest
    of them before concatenation.

    Raises:
        ValueError: If the ranks differ or a non-concatenation axis has more than
            two distinct sizes.
    """
    ranks = {t.ndim for t in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    (rank,) = ranks
    if dim < 0:
        dim += rank

    # sizes_per_axis[a] holds the size of axis `a` for every tensor, in input order.
    sizes_per_axis = list(zip(*(tuple(t.shape) for t in tensors)))

    target_shape = []
    for axis, sizes in enumerate(sizes_per_axis):
        if axis == dim:
            target_shape.append(None)  # filled in per tensor below
            continue
        if len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")
        target_shape.append(max(sizes))

    concat_sizes = sizes_per_axis[dim]
    expanded = []
    for index, tensor in enumerate(tensors):
        shape = list(target_shape)
        shape[dim] = concat_sizes[index]
        expanded.append(tensor.expand(*shape))
    return torch.cat(expanded, dim=dim)
51
+
52
+
53
def rotate_half(x):
    """Rotate adjacent feature pairs: `(a, b)` along the last axis becomes `(-b, a)`.

    The last dimension is interpreted as interleaved pairs and must therefore be even.
    """
    pairs = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pair, 2) axes back into a single feature axis.
    return rotated.flatten(start_dim=-2)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """2D rotary position embedding with an optional un-rotated token prefix.

    Cosine and sine tables are precomputed for an `ft_seq_len x ft_seq_len` grid and
    stored in the non-persistent buffers `freqs_cos` / `freqs_sin`. When
    `num_cls_token > 0`, that many rows with `cos = 1` and `sin = 0` are prepended, so
    the leading tokens are left unchanged by `forward`.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
            base_freqs = 1.0 / (theta**exponents)
        else:
            base_freqs = custom_freqs

        grid_len = pt_seq_len if ft_seq_len is None else ft_seq_len
        positions = torch.arange(grid_len) / grid_len * pt_seq_len

        axis_freqs = torch.einsum("..., f -> ... f", positions, base_freqs)
        axis_freqs = axis_freqs.repeat_interleave(2, dim=-1)
        # Combine the per-axis angles of both grid axes along the feature dimension.
        grid_freqs = broadcat((axis_freqs[:, None, :], axis_freqs[None, :, :]), dim=-1)

        flat_freqs = grid_freqs.view(-1, grid_freqs.shape[-1])  # [N_img, D]
        cos_table = flat_freqs.cos()
        sin_table = flat_freqs.sin()

        if num_cls_token > 0:
            # Identity rotation for the prepended in-context tokens.
            width = cos_table.shape[-1]
            cos_prefix = torch.ones(num_cls_token, width, dtype=cos_table.dtype)
            sin_prefix = torch.zeros(num_cls_token, width, dtype=sin_table.dtype)
            cos_table = torch.cat([cos_prefix, cos_table], dim=0)
            sin_table = torch.cat([sin_prefix, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        """Rotate `t`, laid out as (batch, seq_len, heads, head_dim) by the attention layer."""
        length = t.shape[1]
        cos = self.freqs_cos[:length].to(t.dtype).unsqueeze(1)
        sin = self.freqs_sin[:length].to(t.dtype).unsqueeze(1)
        return t * cos + rotate_half(t) * sin
107
+
108
+
109
def modulate(x, shift, scale):
    """Apply adaLN modulation `x * (1 + scale) + shift`, broadcasting over axis 1 of `x`."""
    gain = 1 + scale[:, None]
    return x * gain + shift[:, None]
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """Patch embedding through a bottleneck.

    A bias-free strided convolution maps each `patch_size x patch_size` patch to
    `pca_dim` channels, then a 1x1 convolution lifts it to `embed_dim`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        patches_w = self.img_size[1] // self.patch_size[1]
        patches_h = self.img_size[0] // self.patch_size[0]
        self.num_patches = patches_w * patches_h

        self.proj1 = nn.Conv2d(
            in_chans, pca_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
        )
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Return patch tokens of shape (batch, num_patches, embed_dim) for an image batch."""
        bottleneck = self.proj1(x)
        features = self.proj2(bottleneck)
        # (B, C, H', W') -> (B, H' * W', C)
        return features.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal features `[cos(t * f), sin(t * f)]` for a 1-D tensor `t`.

        Returns a float tensor of shape `(len(t), dim)`; when `dim` is odd a zero
        column is appended.
        """
        half = dim // 2
        positions = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * positions / half).to(device=t.device)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        parts = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            parts.append(torch.zeros_like(parts[0][:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t, dtype=None):
        """Embed timesteps `t`, casting the sinusoidal features to `dtype` when given."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            features = features.to(dtype=dtype)
        return self.mlp(features)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Looks up class-label embeddings in a table with `num_classes + 1` rows.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        # One extra row beyond the real classes (index `num_classes`).
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        """Return the embedding rows selected by integer `labels`."""
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head RMS q/k normalisation and RoPE."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads

        if qk_norm:
            self.q_norm = RMSNorm(per_head, eps=eps)
            self.k_norm = RMSNorm(per_head, eps=eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Stored as a plain float; passed to scaled_dot_product_attention while training.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        """Attend over `x` of shape (batch, tokens, channels); `rope` is applied to q and k when given."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> (3, batch, heads, tokens, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # The rotary module works on (batch, tokens, heads, head_dim).
            q = rope(q.transpose(1, 2)).transpose(1, 2)
            k = rope(k.transpose(1, 2)).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop if self.training else 0.0
        )
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: `w3(dropout(silu(a) * b))` with `[a, b] = w12(x)`.

    The inner width is `int(hidden_dim * 2 / 3)`.
    """

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        # A single projection produces both the gate and the value halves.
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        activated = self.ffn_dropout(activated)
        return self.w3(activated)
235
+
236
+
237
class JiTBlock(nn.Module):
    """Transformer block with adaLN conditioning: gated attention followed by a gated SwiGLU MLP."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        self.act = nn.SiLU()
        # Produces shift/scale/gate for the attention branch and for the MLP branch.
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        """Update tokens `x` given the conditioning vector `c`; `feat_rope` is forwarded to attention."""
        params = self.adaLN_modulation(self.act(c))
        shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = params.chunk(6, dim=-1)

        # Attention branch
        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn.unsqueeze(1) * self.attn(attn_in, rope=feat_rope)

        # MLP branch
        ffn_in = modulate(self.norm2(x), shift_ffn, scale_ffn)
        x = x + gate_ffn.unsqueeze(1) * self.mlp(ffn_in)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2D sine-cosine position table for a square `grid_size x grid_size` grid.

    Returns an array with one row of `embed_dim` features per grid cell. When
    `cls_token` is true and `extra_tokens > 0`, that many all-zero rows are prepended.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    if not (cls_token and extra_tokens > 0):
        return table
    prefix = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([prefix, table], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode `grid[0]` and `grid[1]` with `embed_dim // 2` features each and join them column-wise.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    per_axis_dim = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Return `[sin(pos * w), cos(pos * w)]` for every flattened position.

    The `embed_dim // 2` frequencies are `w_k = 1 / 10000 ** (k / (embed_dim / 2))`.
    The result has shape `(pos.size, embed_dim)`.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents

    angles = np.outer(pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels.
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout, applied only in the middle half of the stack (blocks with index in
            `[num_layers // 4, num_layers // 4 * 3)`).
        dropout (`float`, defaults to `0.0`):
            Projection/MLP dropout, applied only in the same middle half of the stack.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            PCA bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted. If it lies outside
            `[0, num_layers)`, no in-context tokens are ever inserted.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos)
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one table for image tokens only, one with an identity prefix for the in-context tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks (dropout is enabled only for indices in [drop_start, drop_stop)).
        drop_start = num_layers // 4
        drop_stop = num_layers // 4 * 3
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if drop_start <= i < drop_stop else 0.0,
                    proj_drop=dropout if drop_start <= i < drop_stop else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        r"""
        Predict an image from a noisy image, a timestep and a class label.

        Args:
            hidden_states (`torch.Tensor`):
                Image batch fed to the patch embedder.
            timestep (`torch.Tensor`):
                One timestep per batch element (1-D); it is cast to float before the sinusoidal embedding.
            class_labels (`torch.LongTensor`):
                One class index per batch element.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` with the un-patchified prediction in `sample`, or a 1-tuple holding
            that tensor when `return_dict` is `False`.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        y_emb = self.y_embedder(class_labels)

        # Ensure embeddings match hidden_states dtype
        y_emb = y_emb.to(dtype=hidden_states.dtype)

        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Blocks
        # Track whether the in-context prefix is actually present so that RoPE selection and the final
        # slice stay correct even if `in_context_start` never matches a block index.
        in_context_added = False
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                in_context_added = True

            rope = self.feat_rope_incontext if in_context_added else self.feat_rope

            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were prepended; otherwise image patches would be lost)
        if in_context_added:
            x = x[:, self.in_context_len :]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p); assumes a square token grid.
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
JiT-L-16/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
  }
JiT-L-16/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Lookup table keyed by image size (256 and 512) giving a noise value for each.
# NOTE(review): presumably the default `noise_scale` picked from the transformer's
# `sample_size` when the caller passes none; the consumer is outside this view, confirm.
RECOMMENDED_NOISE_BY_SIZE = {
    256: 1.0,
    512: 2.0,
}
33
+
34
+
35
+ class JiTPipeline(DiffusionPipeline):
36
+ r"""
37
+ Pipeline for image generation using JiT (Just image Transformer).
38
+
39
+ Parameters:
40
+ transformer ([`JiTTransformer2DModel`]):
41
+ A class-conditioned `JiTTransformer2DModel` to denoise the images.
42
+ scheduler ([`JiTScheduler`]):
43
+ Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
44
+ id2label (`dict[int, str]`, *optional*):
45
+ ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
46
+ id2label_cn (`dict[int, str]`, *optional*):
47
+ ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
48
+ """
49
+
50
+ model_cpu_offload_seq = "transformer"
51
+
52
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
    """Load a self-contained variant folder locally or from the Hub.

    The variant directory is resolved in one of three ways:

    * `None`, `""` or `"."`: the directory containing this file.
    * A string containing `/` that is not an existing path: treated as a Hub repo id and
      fetched with `huggingface_hub.snapshot_download` (extra arguments may be supplied
      through a `hub_kwargs` entry in `kwargs`).
    * Anything else: a local path, tried relative to the current working directory first
      and relative to this file's directory otherwise.

    The transformer and scheduler classes are then imported from the `.py` files stored
    inside the variant's component folders, and the remaining `kwargs` are forwarded to
    each component's own `from_pretrained`.

    Examples:
        JiTPipeline.from_pretrained(".")
        JiTPipeline.from_pretrained("./JiT-H-32")
        DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)

    Raises:
        ValueError: If no loadable transformer is found under the resolved variant folder.
    """
    repo_root = Path(__file__).resolve().parent

    if pretrained_model_name_or_path in (None, "", "."):
        variant = repo_root
    elif (
        isinstance(pretrained_model_name_or_path, str)
        and "/" in pretrained_model_name_or_path
        and not Path(pretrained_model_name_or_path).exists()
    ):
        # Not an existing local path: treat the string as a Hub repo id.
        from huggingface_hub import snapshot_download

        hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
        if subfolder:
            # Restrict the download to the requested variant plus the shared labels folder.
            hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
        cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
        variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
    else:
        variant = Path(pretrained_model_name_or_path)
        if not variant.is_absolute():
            # Prefer a path relative to the working directory, else relative to this file.
            candidate = (Path.cwd() / variant).resolve()
            variant = candidate if candidate.exists() else (repo_root / variant).resolve()
        if subfolder:
            variant = variant / subfolder

    model_kwargs = dict(kwargs)
    # sys.path entries added below; removed again in the `finally` clause.
    inserted: List[str] = []

    def _load_component(folder: str, module_name: str, class_name: str):
        # Import `class_name` from `<variant>/<folder>/<module_name>.py` and load it from that
        # folder; returns None when the module file or its config file is missing.
        comp_dir = variant / folder
        module_path = comp_dir / f"{module_name}.py"
        # Despite the name, this checks for a config file, not for weight files.
        has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
        if not module_path.exists() or not has_weights:
            return None

        comp_path = str(comp_dir)
        if comp_path not in sys.path:
            sys.path.insert(0, comp_path)
            inserted.append(comp_path)

        module = importlib.import_module(module_name)
        component_cls = getattr(module, class_name)
        return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

    try:
        transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
        scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

        if transformer is None:
            raise ValueError(f"No loadable transformer found under {variant}")

        variant_path = str(variant)
        id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

        pipe = cls(
            transformer=transformer,
            scheduler=scheduler,
            id2label=id2label,
            id2label_cn=id2label_cn,
        )
        if variant_path and hasattr(pipe, "register_to_config"):
            # Record the source folder so labels can be re-resolved later from the config.
            pipe.register_to_config(_name_or_path=variant_path)
        return pipe
    finally:
        # Undo only the sys.path entries this call added.
        for comp_path in inserted:
            if comp_path in sys.path:
                sys.path.remove(comp_path)
127
+
128
def __init__(
    self,
    transformer,
    scheduler,
    id2label: Optional[Dict[int, str]] = None,
    id2label_cn: Optional[Dict[int, str]] = None,
):
    """Register the transformer and scheduler and build the label lookup tables.

    Args:
        transformer: Denoising model registered as the `transformer` module.
        scheduler: Scheduler registered as the `scheduler` module.
        id2label (`dict[int, str]`, *optional*):
            Class id to English label string; an empty mapping is used when omitted.
        id2label_cn (`dict[int, str]`, *optional*):
            Class id to Chinese label string; an empty mapping is used when omitted.
    """
    super().__init__()
    self.register_modules(transformer=transformer, scheduler=scheduler)

    english = id2label or {}
    chinese = id2label_cn or {}
    self._id2label = english
    self._id2label_cn = chinese
    # Reverse lookups: synonym -> class id.
    self.labels = self._build_label2id(english)
    self.labels_cn = self._build_label2id(chinese)
142
+
143
+ def _ensure_labels_loaded(self) -> None:
144
+ if self._id2label or self._id2label_cn:
145
+ return
146
+ loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
147
+ if loaded_en:
148
+ self._id2label = loaded_en
149
+ self.labels = self._build_label2id(self._id2label)
150
+ if loaded_cn:
151
+ self._id2label_cn = loaded_cn
152
+ self.labels_cn = self._build_label2id(self._id2label_cn)
153
+
154
+ @staticmethod
155
+ def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
156
+ if not variant_path:
157
+ return None
158
+ variant_dir = Path(variant_path).resolve()
159
+ labels_dir = variant_dir.parent / "labels"
160
+ return labels_dir if labels_dir.is_dir() else None
161
+
162
+ @staticmethod
163
+ def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
164
+ filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
165
+ path = labels_dir / filename
166
+ if not path.exists():
167
+ raise FileNotFoundError(path)
168
+ raw = json.loads(path.read_text(encoding="utf-8"))
169
+ return {int(key): value for key, value in raw.items()}
170
+
171
+ @classmethod
172
+ def _load_labels_for_variant(
173
+ cls,
174
+ variant_path: Optional[str],
175
+ ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
176
+ labels_dir = cls._labels_dir_for_variant(variant_path)
177
+ if labels_dir is None:
178
+ return None, None
179
+ try:
180
+ return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
181
+ except FileNotFoundError:
182
+ return None, None
183
+
184
+ @staticmethod
185
+ def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
186
+ label2id: Dict[str, int] = {}
187
+ for class_id, value in id2label.items():
188
+ for synonym in value.split(","):
189
+ synonym = synonym.strip()
190
+ if synonym:
191
+ label2id[synonym] = int(class_id)
192
+ return dict(sorted(label2id.items()))
193
+
194
@property
def id2label(self) -> Dict[int, str]:
    """English label string (comma-separated synonyms) for each ImageNet class id; loaded on first access if needed."""
    self._ensure_labels_loaded()
    english = self._id2label
    return english
199
+
200
@property
def id2label_cn(self) -> Dict[int, str]:
    """Chinese label string (comma-separated synonyms) for each ImageNet class id; loaded on first access if needed."""
    self._ensure_labels_loaded()
    chinese = self._id2label_cn
    return chinese
205
+
206
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
    r"""
    Translate ImageNet label strings into class ids.

    Args:
        label (`str` or `list[str]`):
            A single label or a list of labels. Every entry must be a known synonym in the
            table selected by `lang`.
        lang (`str`, *optional*, defaults to `"en"`):
            `"en"` looks labels up in the English table, `"cn"` in the Chinese one.

    Returns:
        `list[int]`: One class id per requested label, in input order.

    Raises:
        ValueError: If `lang` is invalid, no labels are loaded for `lang`, or a label is unknown.
    """
    if lang not in ("en", "cn"):
        raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

    self._ensure_labels_loaded()
    if lang == "en":
        table = self.labels
    else:
        table = self.labels_cn
    if not table:
        raise ValueError(
            f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
        )

    requested = [label] if isinstance(label, str) else label

    missing = [name for name in requested if name not in table]
    if missing:
        # Show a few valid keys to help the caller correct the request.
        preview = ", ".join(list(table)[:8])
        raise ValueError(
            f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
        )
    return [table[name] for name in requested]
237
+
238
+ def _normalize_class_labels(
239
+ self,
240
+ class_labels: Union[int, str, List[Union[int, str]]],
241
+ ) -> List[int]:
242
+ if isinstance(class_labels, int):
243
+ return [class_labels]
244
+
245
+ if isinstance(class_labels, str):
246
+ return self.get_label_ids(class_labels)
247
+
248
+ if class_labels and isinstance(class_labels[0], str):
249
+ self._ensure_labels_loaded()
250
+ if all(label in self.labels for label in class_labels):
251
+ return self.get_label_ids(class_labels, lang="en")
252
+ if all(label in self.labels_cn for label in class_labels):
253
+ return self.get_label_ids(class_labels, lang="cn")
254
+ raise ValueError(
255
+ "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
256
+ "or Chinese synonyms from `pipe.labels_cn`."
257
+ )
258
+
259
+ return list(class_labels)
260
+
261
+ def _predict_velocity(
262
+ self,
263
+ z_value: torch.Tensor,
264
+ t: torch.Tensor,
265
+ class_labels: torch.Tensor,
266
+ class_null: torch.Tensor,
267
+ do_classifier_free_guidance: bool,
268
+ guidance_scale: float,
269
+ guidance_interval_min: float,
270
+ guidance_interval_max: float,
271
+ ) -> torch.Tensor:
272
+ t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
273
+ if do_classifier_free_guidance:
274
+ z_in = torch.cat([z_value, z_value], dim=0)
275
+ labels = torch.cat([class_labels, class_null], dim=0)
276
+ else:
277
+ z_in = z_value
278
+ labels = class_labels
279
+
280
+ t_batch = t.flatten().expand(z_in.shape[0])
281
+ x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
282
+ v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
283
+
284
+ if not do_classifier_free_guidance:
285
+ return v
286
+
287
+ v_cond, v_uncond = v.chunk(2, dim=0)
288
+ interval_mask = t < guidance_interval_max
289
+ if guidance_interval_min != 0.0:
290
+ interval_mask = interval_mask & (t > guidance_interval_min)
291
+ scale = torch.where(
292
+ interval_mask,
293
+ torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
294
+ torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
295
+ )
296
+ return v_uncond + scale * (v_cond - v_uncond)
297
+
298
def _run_sampler(
    self,
    latents: torch.Tensor,
    class_labels: torch.Tensor,
    class_null: torch.Tensor,
    num_inference_steps: int,
    do_classifier_free_guidance: bool,
    guidance_scale: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    sampling_method: str,
) -> torch.Tensor:
    """
    Integrate `latents` along the scheduler's timestep grid and return the final sample.

    The first `num_inference_steps - 1` updates go through `self.scheduler.step`
    (with a second velocity evaluation when `sampling_method == "heun"`); the last
    update, from `timesteps[-2]` to `timesteps[-1]`, is always a plain Euler step
    computed here.
    """
    self.scheduler.set_timesteps(num_inference_steps, device=latents.device, solver=sampling_method)
    grid = self.scheduler.timesteps
    use_heun = sampling_method == "heun"

    def velocity_at(sample: torch.Tensor, flow_time: torch.Tensor) -> torch.Tensor:
        # Every evaluation shares the same labels and guidance settings.
        return self._predict_velocity(
            sample,
            flow_time,
            class_labels,
            class_null,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
        )

    for idx in self.progress_bar(range(num_inference_steps - 1)):
        t_cur = grid[idx]
        t_nxt = grid[idx + 1]
        v_cur = velocity_at(latents, t_cur)

        step_kwargs = {}
        if use_heun:
            # Euler predictor, then re-evaluate the velocity at the predicted point.
            predictor = latents + (t_nxt - t_cur) * v_cur
            step_kwargs["model_output_next"] = velocity_at(predictor, t_nxt)
        latents = self.scheduler.step(v_cur, t_cur, latents, **step_kwargs).prev_sample

    # Final update: single Euler step that bypasses the scheduler.
    t_cur, t_nxt = grid[-2], grid[-1]
    return latents + (t_nxt - t_cur) * velocity_at(latents, t_cur)
357
+
358
@torch.inference_mode()
def __call__(
    self,
    class_labels: Union[int, str, List[Union[int, str]]],
    guidance_scale: Optional[float] = None,
    guidance_interval_min: float = 0.1,
    guidance_interval_max: float = 1.0,
    noise_scale: Optional[float] = None,
    t_eps: Optional[float] = None,
    sampling_method: Optional[str] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    num_inference_steps: int = 50,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Generate class-conditional images.

    Args:
        class_labels (`int`, `str`, `list[int]`, or `list[str]`):
            ImageNet class indices or human-readable label strings (English or Chinese).
            Resolved through `self._normalize_class_labels`; the batch size is the number
            of resolved labels.
        guidance_scale (`float`, *optional*):
            Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`;
            `None` is treated as `1.0` (no guidance).
        guidance_interval_min (`float`, defaults to `0.1`):
            Lower bound of the CFG interval in flow time `t in [0, 1]`.
        guidance_interval_max (`float`, defaults to `1.0`):
            Upper bound of the CFG interval in flow time.
        noise_scale (`float`, *optional*):
            Multiplier for the initial Gaussian noise. When omitted it is looked up in
            `RECOMMENDED_NOISE_BY_SIZE` by the transformer's `sample_size`, falling back
            to `1.0`.
        t_eps (`float`, *optional*):
            Epsilon clamp for the `1 - t` denominator (scheduler config by default).
            Applies to this call only.
        sampling_method (`str`, *optional*):
            `"heun"` or `"euler"`. Defaults to the scheduler config. Applies to this
            call only.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            RNG for reproducibility.
        num_inference_steps (`int`, defaults to `50`):
            Number of solver steps (at least 2).
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"pt"` returns a CPU tensor in `[0, 1]`, `"np"` a channels-last NumPy array;
            any other value is converted to PIL images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Return [`ImagePipelineOutput`] if True, otherwise a one-element tuple.

    Raises:
        ValueError: If the resolved solver is not `"heun"`/`"euler"` or
            `num_inference_steps < 2`.
    """
    solver = sampling_method or self.scheduler.config.solver
    if solver not in {"heun", "euler"}:
        raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
    if num_inference_steps < 2:
        raise ValueError("num_inference_steps must be >= 2.")

    class_label_ids = self._normalize_class_labels(class_labels)
    do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

    batch_size = len(class_label_ids)
    image_size = int(self.transformer.config.sample_size)
    channels = int(self.transformer.config.in_channels)
    # The index right after the last real class is used as the "null" label for CFG.
    null_class_val = int(self.transformer.config.num_classes)

    if guidance_scale is None:
        guidance_scale = 1.0
    if noise_scale is None:
        noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

    latents = (
        randn_tensor(
            shape=(batch_size, channels, image_size, image_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        )
        * noise_scale
    )

    class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
    # Out-of-range ids are clamped into the valid class range rather than rejected.
    class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
    class_null = torch.full_like(class_labels_t, null_class_val)

    # `t_eps` here and `solver` (inside `set_timesteps`) are written into the scheduler
    # config. Snapshot both so that per-call overrides do not silently become the
    # defaults of every later call on this pipeline.
    saved_scheduler_config = {
        "t_eps": self.scheduler.config.t_eps,
        "solver": self.scheduler.config.solver,
    }
    try:
        if t_eps is not None:
            self.scheduler.register_to_config(t_eps=t_eps)

        latents = self._run_sampler(
            latents,
            class_labels_t,
            class_null,
            num_inference_steps,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
            solver,
        )
    finally:
        self.scheduler.register_to_config(**saved_scheduler_config)

    # Map from [-1, 1] to [0, 1] in float32 on CPU before converting to the output type.
    images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
    if output_type == "pt":
        images = images_pt
    elif output_type == "np":
        images = images_pt.permute(0, 2, 3, 1).numpy()
    else:
        images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

    self.maybe_free_model_hooks()

    if not return_dict:
        return (images,)
    return ImagePipelineOutput(images=images)
JiT-L-16/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-L-16/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Result container returned by `JiTScheduler.step`.

    Args:
        prev_sample (`torch.Tensor`):
            The sample after applying one solver update (Euler or Heun) on the
            scheduler's flow-time grid; feed it back in as `sample` for the next step.
    """

    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            Stored in the config only; not read by any method of this class.
        t_eps (`float`, defaults to `0.05`):
            Lower clamp for the `1 - t` denominator in `velocity_from_prediction`.
        solver (`str`, defaults to `"heun"`):
            `"heun"` or `"euler"`.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")
        # All run-time state stays unset until `set_timesteps` is called.
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise expected by this scheduler (always `1.0`)."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """
        Build the flow-time grid and reset the internal step counter.

        Args:
            num_inference_steps (`int`):
                Number of solver steps (at least 2). The grid has
                `num_inference_steps + 1` points from `0.0` to `1.0`.
            device (`str` or `torch.device`, *optional*):
                Device on which the grid is created.
            solver (`str`, *optional*):
                When given, overrides the `solver` entry of the config.

        Raises:
            ValueError: If `num_inference_steps < 2` or `solver` is not
                `"heun"`/`"euler"`.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        # Validate before touching any state, mirroring `__init__`. An unknown solver
        # name would otherwise be stored and `step` would silently integrate with Euler.
        if solver is not None and solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        # NOTE(review): `sigmas` has `num_inference_steps` entries while `timesteps` has
        # one more; nothing in this class reads `sigmas` - confirm the intended length.
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is accepted for API compatibility and ignored."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """
        Return the grid index to use for the next `step`.

        The internal counter wins whenever it is set; `timestep` is only consulted when
        the counter is `None`, in which case the first grid point within `1e-6` of it is
        used and index `0` is the fallback.
        """
        if self._step_index is not None:
            return self._step_index
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if timestep is None:
            return 0
        t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
        matches = (self.timesteps - t_value).abs() < 1e-6
        if matches.any():
            return int(matches.nonzero(as_tuple=False)[0].item())
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. When omitted, uses the internal step index.
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next` (required for Heun intermediate steps). When it is
                `None` the update is a plain Euler step even if the solver is `"heun"`.
            return_dict (`bool`, defaults to `True`):
                Return a `JiTSchedulerOutput` if True, otherwise a one-element tuple.

        Raises:
            ValueError: If `set_timesteps` has not been called or the grid is exhausted.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            # Heun: average the slopes at both ends of the interval.
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        # Append trailing singleton axes so `t` broadcasts against `sample`.
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-L-16/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 128,
6
+ "dropout": 0.0,
7
+ "hidden_size": 1024,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 8,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 16,
14
+ "num_classes": 1000,
15
+ "num_layers": 24,
16
+ "patch_size": 16,
17
+ "sample_size": 256
18
+ }
JiT-L-16/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9285393d92db078237e8adc552d6c9314c898c710ca1dfb4d3503fda0016b0f
3
+ size 1836593656
JiT-L-16/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """
    Concatenate `tensors` along `dim` after expanding every other axis to a common size.

    All tensors must have the same number of dimensions, and each axis other than
    `dim` may show at most two distinct sizes across the inputs; a `ValueError` is
    raised otherwise.
    """
    ranks = {t.ndim for t in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    rank = next(iter(ranks))
    cat_axis = dim + rank if dim < 0 else dim

    # sizes_per_axis[k] holds the size of axis k for every input tensor.
    sizes_per_axis = list(zip(*(tuple(t.shape) for t in tensors)))
    for axis, sizes in enumerate(sizes_per_axis):
        if axis != cat_axis and len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")

    target_shape = [max(sizes) for sizes in sizes_per_axis]
    expanded = []
    for tensor in tensors:
        shape = list(target_shape)
        # The concatenation axis keeps each tensor's own extent.
        shape[cat_axis] = tensor.shape[cat_axis]
        expanded.append(tensor.expand(*shape))
    return torch.cat(expanded, dim=cat_axis)
51
+
52
+
53
def rotate_half(x):
    """Map each consecutive pair `(a, b)` along the last axis to `(-b, a)`."""
    pairs = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pair, 2) axes back into the original last dimension.
    return rotated.view(*rotated.shape[:-2], -1)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """
    Precomputed 2D rotary position embedding over a square `ft_seq_len x ft_seq_len` grid.

    Cos/sin tables are stored as non-persistent buffers `freqs_cos` / `freqs_sin` of
    width `2 * dim`. When `num_cls_token > 0`, that many identity rows (cos = 1,
    sin = 0) are prepended, so those leading tokens pass through unrotated.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        else:
            base_freqs = custom_freqs

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Grid positions rescaled so that `ft_seq_len` positions span `pt_seq_len`.
        positions = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        angles_1d = torch.einsum("..., f -> ... f", positions, base_freqs)
        angles_1d = angles_1d.repeat_interleave(2, dim=-1)
        # Pair every row angle with every column angle along the feature axis.
        angles_2d = broadcat((angles_1d[:, None, :], angles_1d[None, :, :]), dim=-1)

        width = angles_2d.shape[-1]
        cos_table = angles_2d.cos().view(-1, width)
        sin_table = angles_2d.sin().view(-1, width)

        if num_cls_token > 0:
            # prepend in-context cls token
            cos_pad = torch.ones(num_cls_token, width, dtype=cos_table.dtype)
            sin_pad = torch.zeros(num_cls_token, width, dtype=sin_table.dtype)
            cos_table = torch.cat([cos_pad, cos_table], dim=0)
            sin_table = torch.cat([sin_pad, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        # Applied on (batch, seq_len, heads, head_dim) tensors from attention.
        n_tokens = t.shape[1]
        cos = self.freqs_cos[:n_tokens].to(t.dtype)[:, None, :]
        sin = self.freqs_sin[:n_tokens].to(t.dtype)[:, None, :]
        return t * cos + rotate_half(t) * sin
107
+
108
+
109
def modulate(x, shift, scale):
    """Apply `x * (1 + scale) + shift`, inserting axis 1 into `shift` and `scale` for broadcasting."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """
    Two-stage patch embedding: a strided convolution down to `pca_dim` channels
    (no bias) followed by a 1x1 convolution up to `embed_dim`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        patches_w = self.img_size[1] // self.patch_size[1]
        patches_h = self.img_size[0] // self.patch_size[0]
        self.num_patches = patches_w * patches_h

        self.proj1 = nn.Conv2d(
            in_chans, pca_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
        )
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        bottleneck = self.proj1(x)
        embedded = self.proj2(bottleneck)
        # (B, C, H', W') -> (B, H' * W', C)
        return embedded.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Maps a 1-D batch of scalar timesteps to `hidden_size`-dimensional vectors via a
    sinusoidal encoding followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build `[cos(t * f), sin(t * f)]` features of width `dim` for a 1-D tensor `t`.

        Frequencies decay geometrically from 1 towards `1 / max_period`; a zero column
        is appended when `dim` is odd.
        """
        n_freqs = dim // 2
        exponents = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / n_freqs).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            zero_col = torch.zeros_like(features[:, :1])
            features = torch.cat([features, zero_col], dim=-1)
        return features

    def forward(self, t, dtype=None):
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            # Cast before the MLP so the features match the requested dtype.
            sinusoid = sinusoid.to(dtype=dtype)
        return self.mlp(sinusoid)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Lookup table from integer class labels to `hidden_size` vectors.

    The table has `num_classes + 1` rows, i.e. one index beyond the real classes.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection, optional per-head RMSNorm
    on queries and keys, and an optional rotary-embedding callable.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads

        self.q_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(per_head, eps=eps) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Kept as a plain float: it is passed to `scaled_dot_product_attention` directly.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # `rope` expects (B, N, heads, head_dim), so swap axes around the call.
            q = rope(q.transpose(1, 2)).transpose(1, 2)
            k = rope(k.transpose(1, 2)).transpose(1, 2)

        if self.training:
            drop_prob = self.attn_drop
        else:
            drop_prob = 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_prob)
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """
    SwiGLU feed-forward layer.

    The requested `hidden_dim` is shrunk to `int(hidden_dim * 2 / 3)`; a single linear
    layer produces both the gate and the value halves.
    """

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        activated = self.ffn_dropout(activated)
        return self.w3(activated)
235
+
236
+
237
class JiTBlock(nn.Module):
    """
    Transformer block with adaLN conditioning: the conditioning vector produces
    shift/scale/gate triples for the attention branch and for the MLP branch.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        self.act = nn.SiLU()
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        # One projection of the activated conditioning yields all six modulation terms.
        modulation = self.adaLN_modulation(self.act(c))
        shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = modulation.chunk(6, dim=-1)

        # Gated residual around attention.
        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn.unsqueeze(1) * self.attn(attn_in, rope=feat_rope)

        # Gated residual around the feed-forward layer.
        ffn_in = modulate(self.norm2(x), shift_ffn, scale_ffn)
        x = x + gate_ffn.unsqueeze(1) * self.mlp(ffn_in)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build fixed 2D sin-cos position embeddings for a `grid_size x grid_size` grid.

    Returns an array with one row per grid cell and `embed_dim` columns; when
    `cls_token` is set and `extra_tokens > 0`, that many all-zero rows are prepended.
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # The two meshgrid planes are stacked and reshaped to (2, 1, grid_size, grid_size).
    mesh = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, table], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode `grid[0]` and `grid[1]` with `embed_dim // 2` features each and
    concatenate the two encodings column-wise.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    half_dim = embed_dim // 2
    first_half = get_1d_sincos_pos_embed_from_grid(half_dim, grid[0])
    second_half = get_1d_sincos_pos_embed_from_grid(half_dim, grid[1])
    return np.concatenate([first_half, second_half], axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sin-cos encode the flattened positions in `pos`.

    Returns an array of shape `(pos.size, embed_dim)` whose first half of columns holds
    the sines and whose second half holds the cosines.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    exponents = np.arange(embed_dim // 2, dtype=np.float64)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / 10000**exponents

    flat_pos = pos.reshape(-1)
    # Outer product: one row per position, one column per frequency.
    phase = np.einsum("m,d->md", flat_pos, inv_freq)

    return np.concatenate([np.sin(phase), np.cos(phase)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels (also used as the number of output channels).
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout, applied only to blocks whose index lies in
            `[num_layers // 4, num_layers // 4 * 3)`.
        dropout (`float`, defaults to `0.0`):
            Projection dropout, applied to the same range of blocks.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            PCA bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted. If it does not match
            any block index, no tokens are inserted.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos)
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one table for the plain patch sequence, one with identity rows
        # prepended for the in-context tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
                    proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        """
        Predict the output image for `hidden_states` at `timestep`, conditioned on `class_labels`.

        Args:
            hidden_states (`torch.Tensor`):
                Input images, patchified by `x_embedder`.
            timestep (`torch.Tensor`):
                1-D tensor with one time value per sample; it is converted to float
                inside the timestep embedder.
            class_labels (`torch.LongTensor`):
                Class indices looked up in the label embedding table.
            return_dict (`bool`, defaults to `True`):
                Return a `Transformer2DModelOutput` if True, otherwise a one-element tuple.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        y_emb = self.y_embedder(class_labels)

        # Ensure embeddings match hidden_states dtype
        y_emb = y_emb.to(dtype=hidden_states.dtype)

        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Track whether the in-context tokens were actually prepended. Keying the RoPE
        # choice and the final slice on this flag (rather than on the config values
        # alone) keeps image tokens from being dropped when `in_context_start` does
        # not match any block index.
        in_context_active = False

        # Blocks
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                in_context_active = True

            rope = self.feat_rope_incontext if in_context_active else self.feat_rope

            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were inserted)
        if in_context_active:
            x = x[:, self.in_context_len :]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: (B, h*w, p*p*C) -> (B, C, h*p, w*p); assumes a square token grid.
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
JiT-L-32/model_index.json CHANGED
@@ -1,8 +1,15 @@
1
  {
2
- "_class_name": "JiTPipeline",
 
 
 
3
  "_diffusers_version": "0.36.0",
 
 
 
 
4
  "transformer": [
5
- "jit_diffusers",
6
  "JiTTransformer2DModel"
7
  ]
8
  }
 
1
  {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "JiTPipeline"
5
+ ],
6
  "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "scheduling_jit",
9
+ "JiTScheduler"
10
+ ],
11
  "transformer": [
12
+ "jit_transformer_2d",
13
  "JiTTransformer2DModel"
14
  ]
15
  }
JiT-L-32/pipeline.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
# Default initial-noise scale keyed by the transformer's `sample_size`.
# `JiTPipeline.__call__` looks the size up here and falls back to 1.0 for
# sizes that are not listed.
RECOMMENDED_NOISE_BY_SIZE = {256: 1.0, 512: 2.0}
33
+
34
+
35
class JiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation using JiT (Just image Transformer).

    Parameters:
        transformer ([`JiTTransformer2DModel`]):
            A class-conditioned `JiTTransformer2DModel` to denoise the images.
        scheduler ([`JiTScheduler`]):
            Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
        id2label_cn (`dict[int, str]`, *optional*):
            ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
        """Load a self-contained variant folder locally or from the Hub.

        The variant folder is resolved in one of three ways:

        * `None`, `""` or `"."`: the folder that contains this file.
        * a string containing `"/"` that is not an existing path: treated as a Hub repo id and
          fetched with `huggingface_hub.snapshot_download` (extra download options can be passed
          as `hub_kwargs={...}`); `subfolder` selects the variant inside the snapshot.
        * anything else: a local path, resolved against the current working directory first and
          against this file's folder second; `subfolder` is appended when given.

        All remaining keyword arguments are forwarded to each component's `from_pretrained`.

        Examples:
            JiTPipeline.from_pretrained(".")
            JiTPipeline.from_pretrained("./JiT-H-32")
            DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)

        Raises:
            ValueError: if no loadable transformer is found under the resolved variant folder.
        """
        repo_root = Path(__file__).resolve().parent

        if pretrained_model_name_or_path in (None, "", "."):
            variant = repo_root
        elif (
            isinstance(pretrained_model_name_or_path, str)
            and "/" in pretrained_model_name_or_path
            and not Path(pretrained_model_name_or_path).exists()
        ):
            from huggingface_hub import snapshot_download

            hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
            if subfolder:
                # Only fetch the requested variant plus the shared label files.
                hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
            cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
            variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
        else:
            variant = Path(pretrained_model_name_or_path)
            if not variant.is_absolute():
                candidate = (Path.cwd() / variant).resolve()
                variant = candidate if candidate.exists() else (repo_root / variant).resolve()
            if subfolder:
                variant = variant / subfolder

        model_kwargs = dict(kwargs)
        # sys.path entries added while importing component modules; removed again in `finally`.
        inserted: List[str] = []

        def _load_component(folder: str, module_name: str, class_name: str):
            """Import `module_name` from `variant / folder` and load `class_name` from that folder.

            Returns `None` when the module file or its config file is missing.
            """
            comp_dir = variant / folder
            module_path = comp_dir / f"{module_name}.py"
            has_config = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
            if not module_path.exists() or not has_config:
                return None

            comp_path = str(comp_dir)
            if comp_path not in sys.path:
                sys.path.insert(0, comp_path)
                inserted.append(comp_path)

            # NOTE(review): `import_module` caches by bare module name in `sys.modules`, so a second
            # variant loaded in the same process reuses the first variant's module code.
            module = importlib.import_module(module_name)
            component_cls = getattr(module, class_name)
            return component_cls.from_pretrained(str(comp_dir), **model_kwargs)

        try:
            transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
            scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")

            if transformer is None:
                raise ValueError(f"No loadable transformer found under {variant}")

            variant_path = str(variant)
            id2label, id2label_cn = cls._load_labels_for_variant(variant_path)

            pipe = cls(
                transformer=transformer,
                scheduler=scheduler,
                id2label=id2label,
                id2label_cn=id2label_cn,
            )
            if variant_path and hasattr(pipe, "register_to_config"):
                # Remembered so labels can be lazily re-located by `_ensure_labels_loaded`.
                pipe.register_to_config(_name_or_path=variant_path)
            return pipe
        finally:
            for comp_path in inserted:
                if comp_path in sys.path:
                    sys.path.remove(comp_path)

    def __init__(
        self,
        transformer,
        scheduler,
        id2label: Optional[Dict[int, str]] = None,
        id2label_cn: Optional[Dict[int, str]] = None,
    ):
        """Register the components and build the synonym -> class id lookup tables."""
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler)

        self._id2label = id2label or {}
        self._id2label_cn = id2label_cn or {}
        self.labels = self._build_label2id(self._id2label)
        self.labels_cn = self._build_label2id(self._id2label_cn)

    def _ensure_labels_loaded(self) -> None:
        """Lazily load label files from the configured variant path if no labels are present yet."""
        if self._id2label or self._id2label_cn:
            return
        loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
        if loaded_en:
            self._id2label = loaded_en
            self.labels = self._build_label2id(self._id2label)
        if loaded_cn:
            self._id2label_cn = loaded_cn
            self.labels_cn = self._build_label2id(self._id2label_cn)

    @staticmethod
    def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
        """Return the `labels` directory that sits next to the variant folder, or `None` if absent."""
        if not variant_path:
            return None
        variant_dir = Path(variant_path).resolve()
        labels_dir = variant_dir.parent / "labels"
        return labels_dir if labels_dir.is_dir() else None

    @staticmethod
    def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
        """Read `id2label_en.json` (for `lang == "en"`) or `id2label_cn.json` and convert keys to `int`.

        Raises:
            FileNotFoundError: if the label file does not exist.
        """
        filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
        path = labels_dir / filename
        if not path.exists():
            raise FileNotFoundError(path)
        raw = json.loads(path.read_text(encoding="utf-8"))
        return {int(key): value for key, value in raw.items()}

    @classmethod
    def _load_labels_for_variant(
        cls,
        variant_path: Optional[str],
    ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
        """Load the English and Chinese label maps for a variant.

        Each language is loaded independently, so a missing file for one language no longer
        discards the labels of the other. A missing file (or missing `labels` directory) yields
        `None` for that language.

        Returns:
            `(id2label_en, id2label_cn)`.
        """
        labels_dir = cls._labels_dir_for_variant(variant_path)
        if labels_dir is None:
            return None, None

        loaded: List[Optional[Dict[int, str]]] = []
        for lang in ("en", "cn"):
            try:
                loaded.append(cls._read_id2label(labels_dir, lang))
            except FileNotFoundError:
                loaded.append(None)
        return loaded[0], loaded[1]

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert `id2label` into a synonym -> class id mapping, sorted by synonym.

        Each value is split on commas; empty synonyms are dropped. When the same synonym occurs
        for several class ids, the id encountered last wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        """ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    @property
    def id2label_cn(self) -> Dict[int, str]:
        """ImageNet class id to Chinese label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label_cn

    def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more label strings. Each string must match a synonym in `id2label` (English)
                or `id2label_cn` (Chinese).
            lang (`str`, *optional*, defaults to `"en"`):
                `"en"` uses English synonyms; `"cn"` uses Chinese synonyms.

        Raises:
            ValueError: if `lang` is invalid, no labels are loaded for `lang`, or a label is unknown.
        """
        if lang not in ("en", "cn"):
            raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")

        self._ensure_labels_loaded()
        label2id = self.labels if lang == "en" else self.labels_cn
        if not label2id:
            raise ValueError(
                f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
            )

        if isinstance(label, str):
            label = [label]

        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(
                f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
            )
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
    ) -> List[int]:
        """Convert user-supplied class labels to a list of class ids.

        * An `int` becomes a one-element list.
        * A `str` is handled like a one-element list of strings, so a single Chinese label is
          resolved the same way a list of Chinese labels is.
        * A sequence without strings is returned as a plain list, unchanged.
        * A sequence containing strings has every string resolved: English synonyms are used when
          all strings are English, Chinese synonyms when all are Chinese, otherwise each string is
          looked up in English first and Chinese second. Non-string entries are converted with `int`.

        Raises:
            ValueError: if a string cannot be resolved to a class id.
        """
        if isinstance(class_labels, int):
            return [class_labels]
        if isinstance(class_labels, str):
            class_labels = [class_labels]

        items = list(class_labels)
        names = [item for item in items if isinstance(item, str)]
        if not names:
            return items

        self._ensure_labels_loaded()
        if all(name in self.labels for name in names):
            tables = (self.labels,)
        elif all(name in self.labels_cn for name in names):
            tables = (self.labels_cn,)
        else:
            tables = (self.labels, self.labels_cn)

        resolved: List[int] = []
        for item in items:
            if not isinstance(item, str):
                resolved.append(int(item))
                continue
            table = next((candidate for candidate in tables if item in candidate), None)
            if table is None:
                raise ValueError(
                    "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
                    "or Chinese synonyms from `pipe.labels_cn`."
                )
            resolved.append(table[item])
        return resolved

    def _predict_velocity(
        self,
        z_value: torch.Tensor,
        t: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
    ) -> torch.Tensor:
        """Run the transformer at flow time `t` and return the (optionally guided) velocity.

        With classifier-free guidance the batch is doubled (conditional labels followed by
        `class_null`), and the two velocities are blended as
        `v_uncond + scale * (v_cond - v_uncond)`. `scale` equals `guidance_scale` while
        `t < guidance_interval_max` (and `t > guidance_interval_min` unless that bound is `0.0`),
        and `1.0` otherwise.
        """
        t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
        if do_classifier_free_guidance:
            z_in = torch.cat([z_value, z_value], dim=0)
            labels = torch.cat([class_labels, class_null], dim=0)
        else:
            z_in = z_value
            labels = class_labels

        # One timestep value per batch row for the transformer.
        t_batch = t.flatten().expand(z_in.shape[0])
        x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
        v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)

        if not do_classifier_free_guidance:
            return v

        v_cond, v_uncond = v.chunk(2, dim=0)
        interval_mask = t < guidance_interval_max
        if guidance_interval_min != 0.0:
            interval_mask = interval_mask & (t > guidance_interval_min)
        scale = torch.where(
            interval_mask,
            torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
            torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
        )
        return v_uncond + scale * (v_cond - v_uncond)

    def _run_sampler(
        self,
        latents: torch.Tensor,
        class_labels: torch.Tensor,
        class_null: torch.Tensor,
        num_inference_steps: int,
        do_classifier_free_guidance: bool,
        guidance_scale: float,
        guidance_interval_min: float,
        guidance_interval_max: float,
        sampling_method: str,
    ) -> torch.Tensor:
        """Integrate the latents along the scheduler's time grid.

        The first `num_inference_steps - 1` intervals use the scheduler (`"heun"` evaluates a
        second velocity at the Euler-predicted point); the last interval is always a plain Euler
        step computed here.
        """
        device = latents.device
        self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
        timesteps = self.scheduler.timesteps

        guidance_args = (
            class_labels,
            class_null,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
        )

        for i in self.progress_bar(range(num_inference_steps - 1)):
            t = timesteps[i]
            t_next = timesteps[i + 1]
            v = self._predict_velocity(latents, t, *guidance_args)

            if sampling_method == "heun":
                # Predictor: Euler estimate at t_next, used only to evaluate the second slope.
                latents_euler = latents + (t_next - t) * v
                v_next = self._predict_velocity(latents_euler, t_next, *guidance_args)
                latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
            else:
                latents = self.scheduler.step(v, t, latents).prev_sample

        # Final interval: single Euler step.
        t = timesteps[-2]
        t_next = timesteps[-1]
        v = self._predict_velocity(latents, t, *guidance_args)
        return latents + (t_next - t) * v

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]]],
        guidance_scale: Optional[float] = None,
        guidance_interval_min: float = 0.1,
        guidance_interval_max: float = 1.0,
        noise_scale: Optional[float] = None,
        t_eps: Optional[float] = None,
        sampling_method: Optional[str] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images.

        Args:
            class_labels (`int`, `str`, `list[int]`, or `list[str]`):
                ImageNet class indices or human-readable label strings (English or Chinese).
                Class ids are clamped to `[0, num_classes - 1]`.
            guidance_scale (`float`, *optional*):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval_min (`float`, defaults to `0.1`):
                Lower bound of the CFG interval in flow time `t in [0, 1]`.
            guidance_interval_max (`float`, defaults to `1.0`):
                Upper bound of the CFG interval in flow time.
            noise_scale (`float`, *optional*):
                Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
            t_eps (`float`, *optional*):
                Epsilon clamp for the `1 - t` denominator (scheduler config by default). When given,
                the value is written to the scheduler config and therefore persists across calls.
            sampling_method (`str`, *optional*):
                `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            num_inference_steps (`int`, defaults to `50`):
                Number of solver steps (at least 2).
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"`, `"np"`, or `"pt"`. Any other value is treated as `"pil"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Raises:
            ValueError: if no scheduler is attached, the solver name is invalid, or
                `num_inference_steps < 2`.
        """
        if self.scheduler is None:
            # `from_pretrained` tolerates a missing scheduler folder; fail with a clear message
            # instead of an AttributeError on `self.scheduler.config`.
            raise ValueError("JiTPipeline requires a scheduler, but none is attached to this pipeline.")

        solver = sampling_method or self.scheduler.config.solver
        if solver not in {"heun", "euler"}:
            raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")

        if t_eps is not None:
            self.scheduler.register_to_config(t_eps=t_eps)

        class_label_ids = self._normalize_class_labels(class_labels)
        do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0

        batch_size = len(class_label_ids)
        image_size = int(self.transformer.config.sample_size)
        channels = int(self.transformer.config.in_channels)
        # Index `num_classes` is used as the unconditional ("null") class for CFG.
        null_class_val = int(self.transformer.config.num_classes)

        if guidance_scale is None:
            guidance_scale = 1.0
        if noise_scale is None:
            noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)

        latents = (
            randn_tensor(
                shape=(batch_size, channels, image_size, image_size),
                generator=generator,
                device=self._execution_device,
                dtype=self.transformer.dtype,
            )
            * noise_scale
        )

        class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
        class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
        class_null = torch.full_like(class_labels_t, null_class_val)

        latents = self._run_sampler(
            latents,
            class_labels_t,
            class_null,
            num_inference_steps,
            do_classifier_free_guidance,
            guidance_scale,
            guidance_interval_min,
            guidance_interval_max,
            solver,
        )

        # Map from [-1, 1] to [0, 1] on CPU before converting to the requested output type.
        images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
        if output_type == "pt":
            images = images_pt
        elif output_type == "np":
            images = images_pt.permute(0, 2, 3, 1).numpy()
        else:
            images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())

        self.maybe_free_model_hooks()

        if not return_dict:
            return (images,)
        return ImagePipelineOutput(images=images)
JiT-L-32/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "num_train_timesteps": 1000,
5
+ "t_eps": 0.05,
6
+ "solver": "heun"
7
+ }
JiT-L-32/scheduler/scheduling_jit.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
22
+ from diffusers.utils import BaseOutput
23
+
24
+
25
@dataclass
class JiTSchedulerOutput(BaseOutput):
    """
    Result container returned by `JiTScheduler.step`.

    Args:
        prev_sample (`torch.Tensor`):
            The sample after advancing one solver step on the JiT flow-time grid.
    """

    prev_sample: torch.Tensor
36
+
37
+
38
class JiTScheduler(SchedulerMixin, ConfigMixin):
    """
    Manual flow-matching scheduler for JiT checkpoints.

    Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
    sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
    Heun along that grid.
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        t_eps: float = 5e-2,
        solver: str = "heun",
    ):
        """
        Args:
            num_train_timesteps (`int`, defaults to `1000`):
                Stored in the config; not read by this class.
            t_eps (`float`, defaults to `5e-2`):
                Lower clamp for the `1 - t` denominator in `velocity_from_prediction`.
            solver (`str`, defaults to `"heun"`):
                `"heun"` or `"euler"`.

        Raises:
            ValueError: if `solver` is not `"heun"` or `"euler"`.
        """
        self._validate_solver(solver)
        self.timesteps: Optional[torch.Tensor] = None
        self.sigmas: Optional[List[float]] = None
        self.num_inference_steps: Optional[int] = None
        self._step_index: Optional[int] = None

    @staticmethod
    def _validate_solver(solver: str) -> None:
        """Raise `ValueError` unless `solver` is one of the supported integrators."""
        if solver not in {"heun", "euler"}:
            raise ValueError("solver must be one of: 'heun', 'euler'.")

    @property
    def init_noise_sigma(self) -> float:
        """Scale of the initial noise expected by this scheduler (always `1.0`)."""
        return 1.0

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device, None] = None,
        solver: Optional[str] = None,
    ) -> None:
        """
        Build the time grid and reset the internal step counter.

        Args:
            num_inference_steps (`int`):
                Number of solver steps; the grid has `num_inference_steps + 1` points from 0 to 1.
            device (`str` or `torch.device`, *optional*):
                Device for the grid tensor.
            solver (`str`, *optional*):
                When given, overrides the configured solver (`"heun"` or `"euler"`).

        Raises:
            ValueError: if `num_inference_steps < 2` or `solver` is not a supported name.
        """
        if num_inference_steps < 2:
            raise ValueError("num_inference_steps must be >= 2.")
        if solver is not None:
            # Validate before touching any state: an unknown name stored in the config would
            # otherwise make `step` silently fall back to Euler.
            self._validate_solver(solver)

        self.num_inference_steps = num_inference_steps
        self.timesteps = torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        # NOTE(review): `sigmas` has `num_inference_steps` entries while `timesteps` has one more,
        # and nothing in this class reads `sigmas`; kept unchanged for compatibility.
        sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
        self.sigmas = (1.0 - sigma_grid).tolist()
        self._step_index = 0
        if solver is not None:
            self.register_to_config(solver=solver)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return `sample` unchanged; `timestep` is ignored."""
        del timestep
        return sample

    def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
        """
        Return the grid index for the current step.

        The internal counter takes precedence whenever it is set, which is always the case after
        `set_timesteps`. Only when the counter is `None` is `timestep` matched against the grid
        (tolerance `1e-6`), falling back to index 0 when `timestep` is `None` or matches nothing.

        Raises:
            ValueError: if the counter is unset and `set_timesteps` has not been called.
        """
        if self._step_index is not None:
            return self._step_index
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if timestep is None:
            return 0
        t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
        matches = (self.timesteps - t_value).abs() < 1e-6
        if matches.any():
            return int(matches.nonzero(as_tuple=False)[0].item())
        return 0

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor, None],
        sample: torch.Tensor,
        model_output_next: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
        """
        Integrate one step on the linear `t` grid.

        Args:
            model_output (`torch.Tensor`):
                Velocity `v = (x_pred - z) / (1 - t)` at the current time.
            timestep (`float` or `torch.Tensor`, *optional*):
                Current flow time `t`. Ignored while the internal step index is set (see
                `_resolve_step_index`).
            sample (`torch.Tensor`):
                Current noisy latent `z`.
            model_output_next (`torch.Tensor`, *optional*):
                Velocity at `t_next`. With the `"heun"` solver the two velocities are averaged;
                without it (or with `"euler"`) a plain Euler step is taken.
            return_dict (`bool`, defaults to `True`):
                Return a `JiTSchedulerOutput` when True, otherwise a one-element tuple.

        Raises:
            ValueError: if `set_timesteps` was not called or the grid is already exhausted.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")

        step_index = self._resolve_step_index(timestep)
        if step_index >= len(self.timesteps) - 1:
            raise ValueError("Scheduler has already reached the final timestep.")

        t = self.timesteps[step_index]
        t_next = self.timesteps[step_index + 1]
        dt = t_next - t

        if self.config.solver == "heun" and model_output_next is not None:
            # Heun: average of the slopes at both ends of the interval.
            prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
        else:
            prev_sample = sample + dt * model_output

        self._step_index = step_index + 1

        if not return_dict:
            return (prev_sample,)
        return JiTSchedulerOutput(prev_sample=prev_sample)

    def velocity_from_prediction(
        self,
        sample: torch.Tensor,
        x_pred: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.

        `timestep` is converted to `sample`'s device and dtype and given trailing singleton axes
        until it has as many dimensions as `sample`; `1 - t` is clamped from below at
        `config.t_eps`.
        """
        t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
        while t.ndim < sample.ndim:
            t = t.unsqueeze(-1)
        denom = (1.0 - t).clamp_min(self.config.t_eps)
        return (x_pred - sample) / denom
JiT-L-32/transformer/config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "JiTTransformer2DModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "attention_dropout": 0.0,
5
+ "bottleneck_dim": 128,
6
+ "dropout": 0.0,
7
+ "hidden_size": 1024,
8
+ "in_channels": 3,
9
+ "in_context_len": 32,
10
+ "in_context_start": 8,
11
+ "mlp_ratio": 4.0,
12
+ "norm_eps": 1e-06,
13
+ "num_attention_heads": 16,
14
+ "num_classes": 1000,
15
+ "num_layers": 24,
16
+ "patch_size": 32,
17
+ "sample_size": 512
18
+ }
JiT-L-32/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:121d3917ab50ad034295646734eb9b898167f19419dd65d22946f38c7d183266
3
+ size 1847219704
JiT-L-32/transformer/jit_transformer_2d.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.normalization import RMSNorm
26
+ from diffusers.utils import logging
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+
32
def broadcat(tensors, dim=-1):
    """Concatenate `tensors` along `dim` after expanding every other axis to a shared size.

    All tensors must have the same number of dimensions. On each axis other than `dim` at most
    two distinct sizes may occur; every tensor is expanded to the largest size on that axis.

    Raises:
        ValueError: if the ranks differ or an axis has more than two distinct sizes.
    """
    count = len(tensors)
    ranks = {len(tensor.shape) for tensor in tensors}
    if len(ranks) != 1:
        raise ValueError("tensors must all have the same number of dimensions")
    rank = next(iter(ranks))
    if dim < 0:
        dim = dim + rank

    # sizes_per_axis[a] lists the size of axis `a` for each input tensor, in order.
    sizes_per_axis = list(zip(*(list(tensor.shape) for tensor in tensors)))
    for axis, sizes in enumerate(sizes_per_axis):
        if axis != dim and len(set(sizes)) > 2:
            raise ValueError("invalid dimensions for broadcastable concatenation")

    concat_sizes = sizes_per_axis[dim]
    target_shapes = [
        [concat_sizes[idx] if axis == dim else max(sizes) for axis, sizes in enumerate(sizes_per_axis)]
        for idx in range(count)
    ]
    expanded = [tensor.expand(*shape) for tensor, shape in zip(tensors, target_shapes)]
    return torch.cat(expanded, dim=dim)
51
+
52
+
53
def rotate_half(x):
    """Rotate adjacent channel pairs of the last axis: `(a, b)` becomes `(-b, a)`.

    The last axis is viewed as consecutive pairs, each pair is rotated, and the result is
    flattened back to the original shape.
    """
    pairs = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.view(*rotated.shape[:-2], -1)
58
+
59
+
60
class JiTRotaryEmbedding(nn.Module):
    """Precomputed cos/sin rotary tables for a square grid of positions.

    Per-axis angles are the product of grid positions and frequencies; the row and column
    angles are concatenated on the last axis and the grid is flattened to one row per position.
    With `num_cls_token > 0`, that many identity rows (cos = 1, sin = 0) are prepended so the
    leading tokens are left unrotated.

    Args:
        dim: controls the default frequencies `1 / theta ** (arange(0, dim, 2) / dim)`.
        pt_seq_len: positions are scaled to span `[0, pt_seq_len)`.
        ft_seq_len: number of grid positions per axis; defaults to `pt_seq_len`.
        custom_freqs: optional 1-D frequency tensor used instead of the default frequencies.
        theta: base of the default frequencies.
        num_cls_token: number of identity rows to prepend to the tables.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        theta=10000,
        num_cls_token=0,
    ):
        super().__init__()
        if custom_freqs is None:
            base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        else:
            base_freqs = custom_freqs

        grid_len = pt_seq_len if ft_seq_len is None else ft_seq_len
        positions = torch.arange(grid_len) / grid_len * pt_seq_len

        angles = torch.einsum("..., f -> ... f", positions, base_freqs)
        # Each frequency is duplicated so it lines up with the channel pairs of `rotate_half`.
        angles = angles.repeat_interleave(2, dim=-1)
        angles = broadcat((angles[:, None, :], angles[None, :, :]), dim=-1)

        flat_angles = angles.view(-1, angles.shape[-1])  # [N_img, D]
        cos_table = flat_angles.cos()
        sin_table = flat_angles.sin()

        if num_cls_token > 0:
            # prepend in-context cls token rows (identity rotation)
            width = cos_table.shape[-1]
            identity_cos = torch.ones(num_cls_token, width, dtype=cos_table.dtype)
            identity_sin = torch.zeros(num_cls_token, width, dtype=sin_table.dtype)
            cos_table = torch.cat([identity_cos, cos_table], dim=0)
            sin_table = torch.cat([identity_sin, sin_table], dim=0)

        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, t):
        """Rotate `t`, a (batch, seq_len, heads, head_dim) tensor from attention.

        Only the first `seq_len` table rows are used, cast to `t`'s dtype.
        """
        n_tokens = t.shape[1]
        cos = self.freqs_cos[:n_tokens].to(t.dtype)[:, None, :]
        sin = self.freqs_sin[:n_tokens].to(t.dtype)[:, None, :]
        return t * cos + rotate_half(t) * sin
107
+
108
+
109
def modulate(x, shift, scale):
    """Return `x * (1 + scale) + shift`, with `shift` and `scale` broadcast over axis 1 of `x`."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
111
+
112
+
113
class JiTPatchEmbed(nn.Module):
    """Image-to-patch embedding through a bottleneck.

    `proj1` is a bias-free convolution whose kernel and stride equal the patch size and which
    maps `in_chans` to `pca_dim` channels; `proj2` is a 1x1 convolution from `pca_dim` to
    `embed_dim`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        patches_along_width = self.img_size[1] // self.patch_size[1]
        patches_along_height = self.img_size[0] // self.patch_size[0]
        self.num_patches = patches_along_width * patches_along_height

        self.proj1 = nn.Conv2d(
            in_chans,
            pca_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Embed an image batch and return tokens shaped (batch, num_tokens, embed_dim)."""
        bottleneck = self.proj1(x)
        tokens = self.proj2(bottleneck)
        # Flatten the spatial grid, then move channels last.
        return tokens.flatten(2).transpose(1, 2)
130
+
131
+
132
class JiTTimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal features for a 1-D tensor of timesteps.

        The output is float32 with shape `(len(t), dim)`: cosines first, then sines, plus one
        trailing zero column when `dim` is odd.
        """
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t, dtype=None):
        """Embed timesteps `t`; the features are cast to `dtype` before the MLP when it is given."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        if dtype is not None:
            features = features.to(dtype=dtype)
        return self.mlp(features)
167
+
168
+
169
class JiTLabelEmbedder(nn.Module):
    """
    Look up a learned vector for each class label.

    The table holds ``num_classes + 1`` rows, i.e. one extra index beyond the
    real classes.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        table_rows = num_classes + 1
        self.embedding_table = nn.Embedding(table_rows, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        """Return the embedding rows selected by the integer ``labels``."""
        return self.embedding_table(labels)
182
+
183
+
184
class JiTAttention(nn.Module):
    """Multi-head self-attention with optional per-head q/k RMSNorm and an optional rotary callable."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        def _head_norm():
            # Normalise over the per-head channel dim, or pass through untouched.
            return RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()

        self.q_norm = _head_norm()
        self.k_norm = _head_norm()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Stored as a plain float: it is handed to SDPA as ``dropout_p``.
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rope=None):
        """Attend over the tokens of ``x`` (batch, tokens, channels).

        ``rope``, when given, is called on q and k laid out as
        (batch, tokens, heads, head_dim).
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> (3, batch, heads, tokens, head_dim), then split off q / k / v.
        packed = packed.permute(2, 0, 3, 1, 4)
        q, k, v = packed[0], packed[1], packed[2]

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            # The rotary callable gets tokens before heads; restore the layout afterwards.
            q = rope(q.transpose(1, 2)).transpose(1, 2)
            k = rope(k.transpose(1, 2)).transpose(1, 2)

        # Attention dropout is only active in training mode.
        if self.training:
            drop_prob = self.attn_drop
        else:
            drop_prob = 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_prob)

        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
220
+
221
+
222
class JiTSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer; the requested hidden width is scaled by 2/3."""

    def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        # One fused projection produces both the gate half and the value half.
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        """Compute ``w3(dropout(silu(gate) * value))`` from the fused projection."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        activated = F.silu(gate) * value
        activated = self.ffn_dropout(activated)
        return self.w3(activated)
235
+
236
+
237
class JiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated and gated by a conditioning vector."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=eps)
        self.attn = JiTAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.norm2 = RMSNorm(hidden_size, eps=eps)
        self.mlp = JiTSwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)

        # SiLU + linear produce six chunks: (shift, scale, gate) for each branch.
        self.act = nn.SiLU()
        self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)

    def forward(self, x, c, feat_rope=None):
        """Apply the attention and MLP residual branches, both conditioned on ``c``."""
        modulation = self.adaLN_modulation(self.act(c))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=-1)

        # Attention branch: normalise -> modulate -> attend -> gated residual.
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in, rope=feat_rope)

        # MLP branch: same pattern with the second set of modulation chunks.
        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)

        return x
276
+
277
+
278
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2-D sin-cos positional embedding for a square grid.

    When ``cls_token`` is truthy and ``extra_tokens`` is positive, that many
    all-zero rows are prepended to the embedding.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid takes the width axis first; both planes are stacked on axis 0.
    planes = np.stack(np.meshgrid(axis, axis), axis=0)
    planes = planes.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)
    if not (cls_token and extra_tokens > 0):
        return table
    prefix = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([prefix, table], axis=0)
288
+
289
+
290
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode ``grid[0]`` and ``grid[1]`` with half of ``embed_dim`` each and join them column-wise.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    per_axis_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[0]),
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[1]),
    ]
    return np.concatenate(halves, axis=1)
298
+
299
+
300
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos encode a set of positions.

    ``pos`` is flattened to M values; the result is an (M, embed_dim) array
    whose first half of columns holds sines and second half cosines.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")

    # Frequencies 1 / 10000^(k / (embed_dim / 2)) for k = 0 .. embed_dim/2 - 1.
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents

    flat_pos = pos.reshape(-1)
    # Outer product of positions and frequencies -> (M, embed_dim / 2).
    angles = np.einsum("m,d->md", flat_pos, omega)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
316
+
317
+
318
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer for pixel-space class-conditional generation with JiT
    ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).

    Parameters:
        sample_size (`int`, defaults to `256`):
            Input image resolution (height and width).
        patch_size (`int`, defaults to `16`):
            Patch size for the bottleneck patch embedder.
        in_channels (`int`, defaults to `3`):
            Number of input image channels (also used as the number of output channels).
        hidden_size (`int`, defaults to `768`):
            Transformer hidden dimension.
        num_layers (`int`, defaults to `12`):
            Number of JiT transformer blocks.
        num_attention_heads (`int`, defaults to `12`):
            Number of attention heads per block.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP hidden dimension multiplier.
        attention_dropout (`float`, defaults to `0.0`):
            Attention dropout, applied only to blocks with index `i` in
            `[num_layers // 4, num_layers // 4 * 3)`; all other blocks use `0.0`.
        dropout (`float`, defaults to `0.0`):
            Projection dropout, applied to the same range of blocks as `attention_dropout`.
        num_classes (`int`, defaults to `1000`):
            Number of class labels (null label uses index `num_classes` for CFG).
        bottleneck_dim (`int`, defaults to `128`):
            PCA bottleneck dimension in the patch embedder.
        in_context_len (`int`, defaults to `32`):
            Number of in-context class tokens prepended mid-network.
        in_context_start (`int`, defaults to `4`):
            Block index at which in-context tokens are inserted. If no block has this index
            (for example `in_context_start >= num_layers`) the tokens are never inserted.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for RMSNorm layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        mlp_ratio: float = 4.0,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        num_classes: int = 1000,
        bottleneck_dim: int = 128,
        in_context_len: int = 32,
        in_context_start: int = 4,
        norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.in_context_len = in_context_len
        self.in_context_start = in_context_start
        self.norm_eps = norm_eps
        self.gradient_checkpointing = False

        # Time and Class Embedding
        self.t_embedder = JiTTimestepEmbedder(hidden_size)
        self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)

        # Patch Embedding
        self.x_embedder = JiTPatchEmbed(
            img_size=sample_size,
            patch_size=patch_size,
            in_chans=in_channels,
            pca_dim=bottleneck_dim,
            embed_dim=hidden_size,
            bias=True,
        )

        # Positional Embedding (Fixed Sin-Cos), stored as a persistent buffer
        num_patches = self.x_embedder.num_patches
        pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        # In-context Embedding (learned positional offsets for the class tokens)
        if self.in_context_len > 0:
            self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))

        # RoPE: one instance for patch tokens only, one that also accounts for
        # the `in_context_len` prepended tokens.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = sample_size // patch_size
        self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
        self.feat_rope_incontext = JiTRotaryEmbedding(
            dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
        )

        # Blocks: dropout is enabled only for indices in [drop_begin, drop_end).
        drop_begin = num_layers // 4
        drop_end = num_layers // 4 * 3
        self.blocks = nn.ModuleList(
            [
                JiTBlock(
                    hidden_size,
                    num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop=attention_dropout if (drop_end > i >= drop_begin) else 0.0,
                    proj_drop=dropout if (drop_end > i >= drop_begin) else 0.0,
                    eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Final Layer
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
        self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
        self.act_final = nn.SiLU()
        self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        return_dict: bool = True,
    ):
        """
        Run the transformer on a batch of images.

        Args:
            hidden_states (`torch.Tensor`):
                Input images fed to the patch embedder; presumably `(batch, in_channels, height, width)`.
            timestep (`torch.Tensor`):
                One timestep per sample (1-D). Values are cast to float inside the timestep
                embedder, so continuous timesteps are accepted.
            class_labels (`torch.LongTensor`):
                Integer class indices looked up in the label embedding table.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` with the result in `sample`, or a 1-tuple `(output,)` when
            `return_dict` is `False`.
        """
        t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
        y_emb = self.y_embedder(class_labels)

        # Ensure embeddings match hidden_states dtype
        y_emb = y_emb.to(dtype=hidden_states.dtype)

        # Conditioning vector shared by every block and the final layer.
        c = t_emb + y_emb

        # Patch Embed
        x = self.x_embedder(hidden_states)
        x = x + self.pos_embed.to(x.dtype)

        # Blocks
        # Track whether the in-context tokens were really prepended: if no block index
        # equals `in_context_start`, nothing is inserted and nothing must be sliced off.
        context_inserted = False
        for i, block in enumerate(self.blocks):
            if self.in_context_len > 0 and i == self.in_context_start:
                in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
                in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
                x = torch.cat([in_context_tokens, x], dim=1)
                context_inserted = True

            rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext

            if self.training and self.gradient_checkpointing:
                # `rope` is passed positionally as the block's `feat_rope` argument.
                x = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    c,
                    rope,
                    use_reentrant=False,
                )
            else:
                x = block(x, c, feat_rope=rope)

        # Slice off in-context tokens (only if they were actually prepended above)
        if context_inserted:
            x = x[:, self.in_context_len :]

        # Final Layer
        c = self.act_final(c)
        shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=-1)

        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_final(x)

        # Unpatchify: tokens are assumed to form a square h x w grid.
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
        x = torch.einsum("nhwpqc->nchpwq", x)
        output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
README.md CHANGED
@@ -14,77 +14,67 @@ language:
14
  - en
15
  ---
16
 
17
- # JiT-H/32 (Diffusers)
18
 
19
- This repository is self-contained: model weights and a custom `diffusers` pipeline (`JiTPipeline`) are both included, so no external code repo is required.
20
 
21
- ## Available Checkpoints (All 6)
 
 
22
 
23
- The JiT paper reports six ImageNet checkpoints across 256 and 512 resolutions. Use the following relative paths with `JiTPipeline.from_pretrained(...)`.
24
 
25
- | Checkpoint | Relative path | Resolution | Pre-trained dataset | Recommended CFG | Recommended interval | Recommended noise_scale | FID-50K |
26
- |---|---|---|---|---:|---|---:|---:|
27
- | JiT-B/16 | `./JiT-B-16` | 256x256 | ImageNet 256x256 | 3.0 | `[0.1, 1.0]` | 1.0 | 3.66 |
28
- | JiT-L/16 | `./JiT-L-16` | 256x256 | ImageNet 256x256 | 2.4 | `[0.1, 1.0]` | 1.0 | 2.36 |
29
- | JiT-H/16 | `./JiT-H-16` | 256x256 | ImageNet 256x256 | 2.2 | `[0.1, 1.0]` | 1.0 | 1.86 |
30
- | JiT-B/32 | `./JiT-B-32` | 512x512 | ImageNet 512x512 | 3.0 | `[0.1, 1.0]` | 2.0 | 4.02 |
31
- | JiT-L/32 | `./JiT-L-32` | 512x512 | ImageNet 512x512 | 2.5 | `[0.1, 1.0]` | 2.0 | 2.53 |
32
- | JiT-H/32 | `./JiT-H-32` | 512x512 | ImageNet 512x512 | 2.3 | `[0.1, 1.0]` | 2.0 | 1.94 |
33
 
34
- Source: [Back to Basics: Let Denoising Generative Models Denoise (arXiv:2511.13720)](https://arxiv.org/html/2511.13720).
35
 
36
- ## Demo Image
 
 
 
 
 
 
 
37
 
38
- ![JiT-H/32 test inference](demo_images/jit_h32_test_inference.png)
39
 
40
- ## One-Stop Diffusers Inference
 
 
 
 
 
 
 
 
 
 
41
 
42
  ```python
43
- from pathlib import Path
44
- import sys
45
  import torch
46
 
47
- repo_dir = Path(".").resolve()
48
- sys.path.insert(0, str(repo_dir))
49
- from jit_diffusers import JiTPipeline
 
 
 
50
 
51
- device = "cuda" if torch.cuda.is_available() else "cpu"
52
- pipe = JiTPipeline.from_pretrained("./JiT-H-32").to(device)
53
- pipe.transformer = pipe.transformer.to(device=device, dtype=torch.bfloat16 if device == "cuda" else torch.float32)
54
- pipe.transformer.eval()
55
 
56
- generator = torch.Generator(device=device).manual_seed(42)
57
- output = pipe(
58
- class_labels=[207],
59
  num_inference_steps=50,
60
  guidance_scale=2.3,
61
- guidance_interval_min=0.1,
62
- guidance_interval_max=1.0,
63
- noise_scale=2.0,
64
- t_eps=5e-2,
65
  sampling_method="heun",
66
  generator=generator,
67
- output_type="pil",
68
- )
69
- image = output.images[0]
70
- output_path = Path("./demo_images/jit_h32_test_inference.png")
71
- output_path.parent.mkdir(parents=True, exist_ok=True)
72
- image.save(output_path)
73
- print(f"Saved image to: {output_path}")
74
  ```
75
 
76
- ## Ready-to-Run Commands (All 6 Checkpoints)
77
-
78
- Run these from this repository root (`models/BiliSakura/JiT-diffusers`).
79
-
80
- ```bash
81
- # 256x256 checkpoints
82
- python run_jit_diffusers_inference.py --model_path ./JiT-B-16 --output_path ./demo_images/jit_b16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
83
- python run_jit_diffusers_inference.py --model_path ./JiT-L-16 --output_path ./demo_images/jit_l16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.4 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
84
- python run_jit_diffusers_inference.py --model_path ./JiT-H-16 --output_path ./demo_images/jit_h16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.2 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
85
-
86
- # 512x512 checkpoints
87
- python run_jit_diffusers_inference.py --model_path ./JiT-B-32 --output_path ./demo_images/jit_b32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
88
- python run_jit_diffusers_inference.py --model_path ./JiT-L-32 --output_path ./demo_images/jit_l32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.5 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
89
- python run_jit_diffusers_inference.py --model_path ./JiT-H-32 --output_path ./demo_images/jit_h32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.3 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
90
- ```
 
14
  - en
15
  ---
16
 
17
+ # JiT-diffusers
18
 
19
+ Native diffusers implementation of **JiT** (Just image Transformer). Each variant folder is self-contained:
20
 
21
+ - `pipeline.py` — `JiTPipeline`
22
+ - `scheduler/scheduling_jit.py` — `JiTScheduler` (linear `t in [0, 1]`, Heun/Euler)
23
+ - `transformer/jit_transformer_2d.py` — `JiTTransformer2DModel`
24
 
25
+ Shared ImageNet-1k labels live in [`labels/`](labels/) at the repo root (not duplicated per variant).
26
 
27
+ No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom code in the variant directory.
 
 
 
 
 
 
 
28
 
29
+ ## Available checkpoints
30
 
31
+ | Checkpoint | Path | Resolution | Recommended CFG |
32
+ |---|---|---|---|
33
+ | JiT-B/16 | `./JiT-B-16` | 256×256 | 3.0 |
34
+ | JiT-L/16 | `./JiT-L-16` | 256×256 | 2.4 |
35
+ | JiT-H/16 | `./JiT-H-16` | 256×256 | 2.2 |
36
+ | JiT-B/32 | `./JiT-B-32` | 512×512 | 3.0 |
37
+ | JiT-L/32 | `./JiT-L-32` | 512×512 | 2.5 |
38
+ | JiT-H/32 | `./JiT-H-32` | 512×512 | 2.3 |
39
 
40
+ ## ImageNet class labels
41
 
42
+ | File | Direction | Format |
43
+ |---|---|---|
44
+ | `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
45
+ | `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` |
46
+
47
+ - `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence
48
+ - `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id), sorted for browsing
49
+ - `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")`
50
+ - `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
51
+
52
+ ## Inference
53
 
54
  ```python
55
+ from diffusers import DiffusionPipeline
 
56
  import torch
57
 
58
+ pipe = DiffusionPipeline.from_pretrained(
59
+ "./JiT-H-32",
60
+ trust_remote_code=True,
61
+ )
62
+ pipe.to("cuda")
63
+ pipe.transformer.to(dtype=torch.bfloat16)
64
 
65
+ # Numeric or human-readable labels
66
+ print(pipe.id2label[207])
67
+ print(pipe.get_label_ids("golden retriever"))
 
68
 
69
+ generator = torch.Generator(device="cuda").manual_seed(42)
70
+ images = pipe(
71
+ class_labels="golden retriever",
72
  num_inference_steps=50,
73
  guidance_scale=2.3,
 
 
 
 
74
  sampling_method="heun",
75
  generator=generator,
76
+ ).images
77
+ images[0].save("output.png")
 
 
 
 
 
78
  ```
79
 
80
+ Load a **variant subfolder** (e.g. `./JiT-H-32`), not the repo root.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo.png CHANGED

Git LFS Details

  • SHA256: d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB

Git LFS Details

  • SHA256: f5fdbd0300f895de7642229d1294aff74facd75c0bb4c4a01efa8c75b14b6fc4
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
demo_images/jit_h32_final_test.png ADDED

Git LFS Details

  • SHA256: bc6804e8a82ad4873a6e9c9e2cf31a7ab901516184cca29955d12b59a45a8920
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
demo_images/jit_h32_test_inference.png CHANGED

Git LFS Details

  • SHA256: d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB

Git LFS Details

  • SHA256: c3a6657c5ac1b7e50dec1dfff0fc02759b1b0e1d5cff75cb3db45af1391fef73
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
labels/__pycache__/imagenet_labels.cpython-312.pyc ADDED
Binary file (3.24 kB). View file
 
labels/id2label_cn.json ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "丁鲷",
3
+ "1": "金鱼",
4
+ "2": "大白鲨",
5
+ "3": "虎鲨",
6
+ "4": "锤头鲨",
7
+ "5": "电鳐",
8
+ "6": "黄貂鱼",
9
+ "7": "公鸡",
10
+ "8": "母鸡",
11
+ "9": "鸵鸟",
12
+ "10": "燕雀",
13
+ "11": "金翅雀",
14
+ "12": "家朱雀",
15
+ "13": "灯芯草雀",
16
+ "14": "靛蓝雀,靛蓝鸟",
17
+ "15": "蓝鹀",
18
+ "16": "夜莺",
19
+ "17": "松鸦",
20
+ "18": "喜鹊",
21
+ "19": "山雀",
22
+ "20": "河鸟",
23
+ "21": "鸢(猛禽)",
24
+ "22": "秃头鹰",
25
+ "23": "秃鹫",
26
+ "24": "大灰猫头鹰",
27
+ "25": "欧洲火蝾螈",
28
+ "26": "普通蝾螈",
29
+ "27": "水蜥",
30
+ "28": "斑点蝾螈",
31
+ "29": "蝾螈,泥狗",
32
+ "30": "牛蛙",
33
+ "31": "树蛙",
34
+ "32": "尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍",
35
+ "33": "红海龟",
36
+ "34": "皮革龟",
37
+ "35": "泥龟",
38
+ "36": "淡水龟",
39
+ "37": "箱龟",
40
+ "38": "带状壁虎",
41
+ "39": "普通鬣蜥",
42
+ "40": "美国变色龙",
43
+ "41": "鞭尾蜥蜴",
44
+ "42": "飞龙科蜥蜴",
45
+ "43": "褶边蜥蜴",
46
+ "44": "鳄鱼蜥蜴",
47
+ "45": "毒蜥",
48
+ "46": "绿蜥蜴",
49
+ "47": "非洲变色龙",
50
+ "48": "科莫多蜥蜴",
51
+ "49": "非洲鳄,尼罗河鳄鱼",
52
+ "50": "美国鳄鱼,鳄鱼",
53
+ "51": "三角龙",
54
+ "52": "雷蛇,蠕虫蛇",
55
+ "53": "环蛇,环颈蛇",
56
+ "54": "希腊蛇",
57
+ "55": "绿蛇,草蛇",
58
+ "56": "国王蛇",
59
+ "57": "袜带蛇,草蛇",
60
+ "58": "水蛇",
61
+ "59": "藤蛇",
62
+ "60": "夜蛇",
63
+ "61": "大蟒蛇",
64
+ "62": "岩石蟒蛇,岩蛇,蟒蛇",
65
+ "63": "印度眼镜蛇",
66
+ "64": "绿曼巴",
67
+ "65": "海蛇",
68
+ "66": "角腹蛇",
69
+ "67": "菱纹响尾蛇",
70
+ "68": "角响尾蛇",
71
+ "69": "三叶虫",
72
+ "70": "盲蜘蛛",
73
+ "71": "蝎子",
74
+ "72": "黑金花园蜘蛛",
75
+ "73": "谷仓蜘蛛",
76
+ "74": "花园蜘蛛",
77
+ "75": "黑寡妇蜘蛛",
78
+ "76": "狼蛛",
79
+ "77": "狼蜘蛛,狩猎蜘蛛",
80
+ "78": "壁虱",
81
+ "79": "蜈蚣",
82
+ "80": "黑松鸡",
83
+ "81": "松鸡,雷鸟",
84
+ "82": "披肩鸡,披肩榛鸡",
85
+ "83": "草原鸡,草原松鸡",
86
+ "84": "孔雀",
87
+ "85": "鹌鹑",
88
+ "86": "鹧鸪",
89
+ "87": "非洲灰鹦鹉",
90
+ "88": "金刚鹦鹉",
91
+ "89": "硫冠鹦鹉",
92
+ "90": "短尾鹦鹉",
93
+ "91": "褐翅鸦鹃",
94
+ "92": "蜜蜂",
95
+ "93": "犀鸟",
96
+ "94": "蜂鸟",
97
+ "95": "鹟䴕",
98
+ "96": "犀鸟",
99
+ "97": "野鸭",
100
+ "98": "红胸秋沙鸭",
101
+ "99": "鹅",
102
+ "100": "黑天鹅",
103
+ "101": "大象",
104
+ "102": "针鼹鼠",
105
+ "103": "鸭嘴兽",
106
+ "104": "沙袋鼠",
107
+ "105": "考拉,考拉熊",
108
+ "106": "袋熊",
109
+ "107": "水母",
110
+ "108": "海葵",
111
+ "109": "脑珊瑚",
112
+ "110": "扁形虫扁虫",
113
+ "111": "线虫,蛔虫",
114
+ "112": "海螺",
115
+ "113": "蜗牛",
116
+ "114": "鼻涕虫",
117
+ "115": "海参",
118
+ "116": "石鳖",
119
+ "117": "鹦鹉螺",
120
+ "118": "珍宝蟹",
121
+ "119": "石蟹",
122
+ "120": "招潮蟹",
123
+ "121": "帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹",
124
+ "122": "美国龙虾,缅因州龙虾",
125
+ "123": "大螯虾",
126
+ "124": "小龙虾",
127
+ "125": "寄居蟹",
128
+ "126": "等足目动物(明虾和螃蟹近亲)",
129
+ "127": "白鹳",
130
+ "128": "黑鹳",
131
+ "129": "鹭",
132
+ "130": "火烈鸟",
133
+ "131": "小蓝鹭",
134
+ "132": "美国鹭,大白鹭",
135
+ "133": "麻鸦",
136
+ "134": "鹤",
137
+ "135": "秧鹤",
138
+ "136": "欧洲水鸡,紫水鸡",
139
+ "137": "沼泽泥母鸡,水母鸡",
140
+ "138": "鸨",
141
+ "139": "红翻石鹬",
142
+ "140": "红背鹬,黑腹滨鹬",
143
+ "141": "红脚鹬",
144
+ "142": "半蹼鹬",
145
+ "143": "蛎鹬",
146
+ "144": "鹈鹕",
147
+ "145": "国王企鹅",
148
+ "146": "信天翁,大海鸟",
149
+ "147": "灰鲸",
150
+ "148": "杀人鲸,逆戟鲸,虎鲸",
151
+ "149": "海牛",
152
+ "150": "海狮",
153
+ "151": "奇瓦瓦",
154
+ "152": "日本猎犬",
155
+ "153": "马尔济斯犬",
156
+ "154": "狮子狗",
157
+ "155": "西施犬",
158
+ "156": "布莱尼姆猎犬",
159
+ "157": "巴比狗",
160
+ "158": "玩具犬",
161
+ "159": "罗得西亚长背猎狗",
162
+ "160": "阿富汗猎犬",
163
+ "161": "猎犬",
164
+ "162": "比格犬,猎兔犬",
165
+ "163": "侦探犬",
166
+ "164": "蓝色快狗",
167
+ "165": "黑褐猎浣熊犬",
168
+ "166": "沃克猎犬",
169
+ "167": "英国猎狐犬",
170
+ "168": "美洲赤狗",
171
+ "169": "俄罗斯猎狼犬",
172
+ "170": "爱尔兰猎狼犬",
173
+ "171": "意大利灰狗",
174
+ "172": "惠比特犬",
175
+ "173": "依比沙猎犬",
176
+ "174": "挪威猎犬",
177
+ "175": "奥达猎犬,水獭猎犬",
178
+ "176": "沙克犬,瞪羚猎犬",
179
+ "177": "苏格兰猎鹿犬,猎鹿犬",
180
+ "178": "威玛猎犬",
181
+ "179": "斯塔福德郡牛头梗,斯塔福德郡斗牛梗",
182
+ "180": "美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗",
183
+ "181": "贝德灵顿梗",
184
+ "182": "边境梗",
185
+ "183": "凯丽蓝梗",
186
+ "184": "爱尔兰梗",
187
+ "185": "诺福克梗",
188
+ "186": "诺维奇梗",
189
+ "187": "约克郡梗",
190
+ "188": "刚毛猎狐梗",
191
+ "189": "莱克兰梗",
192
+ "190": "锡利哈姆梗",
193
+ "191": "艾尔谷犬",
194
+ "192": "凯恩梗",
195
+ "193": "澳大利亚梗",
196
+ "194": "丹迪丁蒙梗",
197
+ "195": "波士顿梗",
198
+ "196": "迷你雪纳瑞犬",
199
+ "197": "巨型雪纳瑞犬",
200
+ "198": "标准雪纳瑞犬",
201
+ "199": "苏格兰梗",
202
+ "200": "西藏梗,菊花狗",
203
+ "201": "丝毛梗",
204
+ "202": "软毛麦色梗",
205
+ "203": "西高地白梗",
206
+ "204": "拉萨阿普索犬",
207
+ "205": "平毛寻回犬",
208
+ "206": "卷毛寻回犬",
209
+ "207": "金毛猎犬",
210
+ "208": "拉布拉多猎犬",
211
+ "209": "乞沙比克猎犬",
212
+ "210": "德国短毛猎犬",
213
+ "211": "维兹拉犬",
214
+ "212": "英国谍犬",
215
+ "213": "爱尔兰雪达犬,红色猎犬",
216
+ "214": "戈登雪达犬",
217
+ "215": "布列塔尼犬猎犬",
218
+ "216": "黄毛,黄毛猎犬",
219
+ "217": "英国史宾格犬",
220
+ "218": "威尔士史宾格犬",
221
+ "219": "可卡犬,英国可卡犬",
222
+ "220": "萨塞克斯猎犬",
223
+ "221": "爱尔兰水猎犬",
224
+ "222": "哥威斯犬",
225
+ "223": "舒柏奇犬",
226
+ "224": "比利时牧羊犬",
227
+ "225": "马里努阿犬",
228
+ "226": "伯瑞犬",
229
+ "227": "凯尔皮犬",
230
+ "228": "匈牙利牧羊犬",
231
+ "229": "老英国牧羊犬",
232
+ "230": "喜乐蒂牧羊犬",
233
+ "231": "牧羊犬",
234
+ "232": "边境牧羊犬",
235
+ "233": "法兰德斯牧牛狗",
236
+ "234": "罗特韦尔犬",
237
+ "235": "德国牧羊犬,德国警犬,阿尔萨斯",
238
+ "236": "多伯曼犬,杜宾犬",
239
+ "237": "迷你杜宾犬",
240
+ "238": "大瑞士山地犬",
241
+ "239": "伯恩山犬",
242
+ "240": "Appenzeller狗",
243
+ "241": "EntleBucher狗",
244
+ "242": "拳师狗",
245
+ "243": "斗牛獒",
246
+ "244": "藏獒",
247
+ "245": "法国斗牛犬",
248
+ "246": "大丹犬",
249
+ "247": "圣伯纳德狗",
250
+ "248": "爱斯基摩犬,哈士奇",
251
+ "249": "雪橇犬,阿拉斯加爱斯基摩狗",
252
+ "250": "哈士奇",
253
+ "251": "达尔马提亚,教练车狗",
254
+ "252": "狮毛狗",
255
+ "253": "巴辛吉狗",
256
+ "254": "哈巴狗,狮子狗",
257
+ "255": "莱昂贝格狗",
258
+ "256": "纽芬兰岛狗",
259
+ "257": "大白熊犬",
260
+ "258": "萨摩耶犬",
261
+ "259": "博美犬",
262
+ "260": "松狮,松狮",
263
+ "261": "荷兰卷尾狮毛狗",
264
+ "262": "布鲁塞尔格林芬犬",
265
+ "263": "彭布洛克威尔士科基犬",
266
+ "264": "威尔士柯基犬",
267
+ "265": "玩具贵宾犬",
268
+ "266": "迷你贵宾犬",
269
+ "267": "标准贵宾犬",
270
+ "268": "墨西哥无毛犬",
271
+ "269": "灰狼",
272
+ "270": "白狼,北极狼",
273
+ "271": "红太狼,鬃狼,犬犬鲁弗斯",
274
+ "272": "狼,草原狼,刷狼,郊狼",
275
+ "273": "澳洲野狗,澳大利亚野犬",
276
+ "274": "豺",
277
+ "275": "非洲猎犬,土狼犬",
278
+ "276": "鬣狗",
279
+ "277": "红狐狸",
280
+ "278": "沙狐",
281
+ "279": "北极狐狸,白狐狸",
282
+ "280": "灰狐狸",
283
+ "281": "虎斑猫",
284
+ "282": "山猫,虎猫",
285
+ "283": "波斯猫",
286
+ "284": "暹罗暹罗猫,",
287
+ "285": "埃及猫",
288
+ "286": "美洲狮,美洲豹",
289
+ "287": "猞猁,山猫",
290
+ "288": "豹子",
291
+ "289": "雪豹",
292
+ "290": "美洲虎",
293
+ "291": "狮子",
294
+ "292": "老虎",
295
+ "293": "猎豹",
296
+ "294": "棕熊",
297
+ "295": "美洲黑熊",
298
+ "296": "冰熊,北极熊",
299
+ "297": "懒熊",
300
+ "298": "猫鼬",
301
+ "299": "猫鼬,海猫",
302
+ "300": "虎甲虫",
303
+ "301": "瓢虫",
304
+ "302": "土鳖虫",
305
+ "303": "天牛",
306
+ "304": "龟甲虫",
307
+ "305": "粪甲虫",
308
+ "306": "犀牛甲虫",
309
+ "307": "象甲",
310
+ "308": "苍蝇",
311
+ "309": "蜜蜂",
312
+ "310": "蚂蚁",
313
+ "311": "蚱蜢",
314
+ "312": "蟋蟀",
315
+ "313": "竹节虫",
316
+ "314": "蟑螂",
317
+ "315": "螳螂",
318
+ "316": "蝉",
319
+ "317": "叶蝉",
320
+ "318": "草蜻蛉",
321
+ "319": "蜻蜓",
322
+ "320": "豆娘,蜻蛉",
323
+ "321": "优红蛱蝶",
324
+ "322": "小环蝴蝶",
325
+ "323": "君主蝴蝶,大斑蝶",
326
+ "324": "菜粉蝶",
327
+ "325": "白蝴蝶",
328
+ "326": "灰蝶",
329
+ "327": "海星",
330
+ "328": "海胆",
331
+ "329": "海参,海黄瓜",
332
+ "330": "野兔",
333
+ "331": "兔",
334
+ "332": "安哥拉兔",
335
+ "333": "仓鼠",
336
+ "334": "刺猬,豪猪,",
337
+ "335": "黑松鼠",
338
+ "336": "土拨鼠",
339
+ "337": "海狸",
340
+ "338": "豚鼠,豚鼠",
341
+ "339": "栗色马",
342
+ "340": "斑马",
343
+ "341": "猪",
344
+ "342": "野猪",
345
+ "343": "疣猪",
346
+ "344": "河马",
347
+ "345": "牛",
348
+ "346": "水牛,亚洲水牛",
349
+ "347": "野牛",
350
+ "348": "公羊",
351
+ "349": "大角羊,洛矶山大角羊",
352
+ "350": "山羊",
353
+ "351": "狷羚",
354
+ "352": "黑斑羚",
355
+ "353": "瞪羚",
356
+ "354": "阿拉伯单峰骆驼,骆驼",
357
+ "355": "羊驼",
358
+ "356": "黄鼠狼",
359
+ "357": "水貂",
360
+ "358": "臭猫",
361
+ "359": "黑足鼬",
362
+ "360": "水獭",
363
+ "361": "臭鼬,木猫",
364
+ "362": "獾",
365
+ "363": "犰狳",
366
+ "364": "树懒",
367
+ "365": "猩猩,婆罗洲猩猩",
368
+ "366": "大猩猩",
369
+ "367": "黑猩猩",
370
+ "368": "长臂猿",
371
+ "369": "合趾猿长臂猿,合趾猿",
372
+ "370": "长尾猴",
373
+ "371": "赤猴",
374
+ "372": "狒狒",
375
+ "373": "恒河猴,猕猴",
376
+ "374": "白头叶猴",
377
+ "375": "疣猴",
378
+ "376": "长鼻猴",
379
+ "377": "狨(美洲产小型长尾猴)",
380
+ "378": "卷尾猴",
381
+ "379": "吼猴",
382
+ "380": "伶猴",
383
+ "381": "蜘蛛猴",
384
+ "382": "松鼠猴",
385
+ "383": "马达加斯加环尾狐猴,鼠狐猴",
386
+ "384": "大狐猴,马达加斯加大狐猴",
387
+ "385": "印度大象,亚洲象",
388
+ "386": "非洲象,非洲象",
389
+ "387": "小熊猫",
390
+ "388": "大熊猫",
391
+ "389": "杖鱼",
392
+ "390": "鳗鱼",
393
+ "391": "银鲑,银鲑鱼",
394
+ "392": "三色刺蝶鱼",
395
+ "393": "海葵鱼",
396
+ "394": "鲟鱼",
397
+ "395": "雀鳝",
398
+ "396": "狮子鱼",
399
+ "397": "河豚",
400
+ "398": "算盘",
401
+ "399": "长袍",
402
+ "400": "学位袍",
403
+ "401": "手风琴",
404
+ "402": "原声吉他",
405
+ "403": "航空母舰",
406
+ "404": "客机",
407
+ "405": "飞艇",
408
+ "406": "祭坛",
409
+ "407": "救护车",
410
+ "408": "水陆两用车",
411
+ "409": "模拟时钟",
412
+ "410": "蜂房",
413
+ "411": "围裙",
414
+ "412": "垃圾桶",
415
+ "413": "攻击步枪,枪",
416
+ "414": "背包",
417
+ "415": "面包店,面包铺,",
418
+ "416": "平衡木",
419
+ "417": "热气球",
420
+ "418": "圆珠笔",
421
+ "419": "创可贴",
422
+ "420": "班卓琴",
423
+ "421": "栏杆,楼梯扶手",
424
+ "422": "杠铃",
425
+ "423": "理发师的椅子",
426
+ "424": "理发店",
427
+ "425": "牲口棚",
428
+ "426": "晴雨表",
429
+ "427": "圆筒",
430
+ "428": "园地小车,手推车",
431
+ "429": "棒球",
432
+ "430": "篮球",
433
+ "431": "婴儿床",
434
+ "432": "巴松管,低音管",
435
+ "433": "游泳帽",
436
+ "434": "沐浴毛巾",
437
+ "435": "浴缸,澡盆",
438
+ "436": "沙滩车,旅行车",
439
+ "437": "灯塔",
440
+ "438": "高脚杯",
441
+ "439": "熊皮高帽",
442
+ "440": "啤酒瓶",
443
+ "441": "啤酒杯",
444
+ "442": "钟塔",
445
+ "443": "(小儿用的)围嘴",
446
+ "444": "串联自行车,",
447
+ "445": "比基尼",
448
+ "446": "装订册",
449
+ "447": "双筒望远镜",
450
+ "448": "鸟舍",
451
+ "449": "船库",
452
+ "450": "雪橇",
453
+ "451": "饰扣式领带",
454
+ "452": "阔边女帽",
455
+ "453": "书橱",
456
+ "454": "书店,书摊",
457
+ "455": "瓶盖",
458
+ "456": "弓箭",
459
+ "457": "蝴蝶结领结",
460
+ "458": "铜制牌位",
461
+ "459": "奶罩",
462
+ "460": "防波堤,海堤",
463
+ "461": "铠甲",
464
+ "462": "扫帚",
465
+ "463": "桶",
466
+ "464": "扣环",
467
+ "465": "防弹背心",
468
+ "466": "动车,子弹头列车",
469
+ "467": "肉铺,肉菜市场",
470
+ "468": "出租车",
471
+ "469": "大锅",
472
+ "470": "蜡烛",
473
+ "471": "大炮",
474
+ "472": "独木舟",
475
+ "473": "开瓶器,开罐器",
476
+ "474": "开衫",
477
+ "475": "车镜",
478
+ "476": "旋转木马",
479
+ "477": "木匠的工具包,工具包",
480
+ "478": "纸箱",
481
+ "479": "车轮",
482
+ "480": "取款机,自动取款机",
483
+ "481": "盒式录音带",
484
+ "482": "卡带播放器",
485
+ "483": "城堡",
486
+ "484": "双体船",
487
+ "485": "CD播放器",
488
+ "486": "大提琴",
489
+ "487": "移动电话,手机",
490
+ "488": "铁链",
491
+ "489": "围栏",
492
+ "490": "链甲",
493
+ "491": "电锯,油锯",
494
+ "492": "箱子",
495
+ "493": "衣柜,洗脸台",
496
+ "494": "编钟,钟,锣",
497
+ "495": "中国橱柜",
498
+ "496": "圣诞袜",
499
+ "497": "教堂,教堂建筑",
500
+ "498": "电影院,剧场",
501
+ "499": "切肉刀,菜刀",
502
+ "500": "悬崖屋",
503
+ "501": "斗篷",
504
+ "502": "木屐,木鞋",
505
+ "503": "鸡尾酒调酒器",
506
+ "504": "咖啡杯",
507
+ "505": "咖啡壶",
508
+ "506": "螺旋结构(楼梯)",
509
+ "507": "组合锁",
510
+ "508": "电脑键盘,键盘",
511
+ "509": "糖果,糖果店",
512
+ "510": "集装箱船",
513
+ "511": "敞篷车",
514
+ "512": "开瓶器,瓶螺杆",
515
+ "513": "短号,喇叭",
516
+ "514": "牛仔靴",
517
+ "515": "牛仔帽",
518
+ "516": "摇篮",
519
+ "517": "起重机",
520
+ "518": "头盔",
521
+ "519": "板条箱",
522
+ "520": "小儿床",
523
+ "521": "砂锅",
524
+ "522": "槌球",
525
+ "523": "拐杖",
526
+ "524": "胸甲",
527
+ "525": "大坝,堤防",
528
+ "526": "书桌",
529
+ "527": "台式电脑",
530
+ "528": "有线电话",
531
+ "529": "尿布湿",
532
+ "530": "数字时钟",
533
+ "531": "数字手表",
534
+ "532": "餐桌板",
535
+ "533": "抹布",
536
+ "534": "洗碗机,洗碟机",
537
+ "535": "盘式制动器",
538
+ "536": "码头,船坞,码头设施",
539
+ "537": "狗拉雪橇",
540
+ "538": "圆顶",
541
+ "539": "门垫,垫子",
542
+ "540": "钻井平台,海上钻井",
543
+ "541": "鼓,乐器,鼓膜",
544
+ "542": "鼓槌",
545
+ "543": "哑铃",
546
+ "544": "荷兰烤箱",
547
+ "545": "电风扇,鼓风机",
548
+ "546": "电吉他",
549
+ "547": "电力机车",
550
+ "548": "电视,电视柜",
551
+ "549": "信封",
552
+ "550": "浓缩咖啡机",
553
+ "551": "扑面粉",
554
+ "552": "女用长围巾",
555
+ "553": "文件,文件柜,档案柜",
556
+ "554": "消防船",
557
+ "555": "消防车",
558
+ "556": "火炉栏",
559
+ "557": "旗杆",
560
+ "558": "长笛",
561
+ "559": "折叠椅",
562
+ "560": "橄榄球头盔",
563
+ "561": "叉车",
564
+ "562": "喷泉",
565
+ "563": "钢笔",
566
+ "564": "有四根帷柱的床",
567
+ "565": "运货车厢",
568
+ "566": "圆号,喇叭",
569
+ "567": "煎锅",
570
+ "568": "裘皮大衣",
571
+ "569": "垃圾车",
572
+ "570": "防毒面具,呼吸器",
573
+ "571": "汽油泵",
574
+ "572": "高脚杯",
575
+ "573": "卡丁车",
576
+ "574": "高尔夫球",
577
+ "575": "高尔夫球车",
578
+ "576": "狭长小船",
579
+ "577": "锣",
580
+ "578": "礼服",
581
+ "579": "钢琴",
582
+ "580": "温室,苗圃",
583
+ "581": "散热器格栅",
584
+ "582": "杂货店,食品市场",
585
+ "583": "断头台",
586
+ "584": "小发夹",
587
+ "585": "头发喷雾",
588
+ "586": "半履带装甲车",
589
+ "587": "锤子",
590
+ "588": "大篮子",
591
+ "589": "手摇鼓风机,吹风机",
592
+ "590": "手提电脑",
593
+ "591": "手帕",
594
+ "592": "硬盘",
595
+ "593": "口琴,口风琴",
596
+ "594": "竖琴",
597
+ "595": "收割机",
598
+ "596": "斧头",
599
+ "597": "手枪皮套",
600
+ "598": "家庭影院",
601
+ "599": "蜂窝",
602
+ "600": "钩爪",
603
+ "601": "衬裙",
604
+ "602": "单杠",
605
+ "603": "马车",
606
+ "604": "沙漏",
607
+ "605": "手机,iPad",
608
+ "606": "熨斗",
609
+ "607": "南瓜灯笼",
610
+ "608": "牛仔裤,蓝色牛仔裤",
611
+ "609": "吉普车",
612
+ "610": "运动衫,T恤",
613
+ "611": "拼图",
614
+ "612": "人力车",
615
+ "613": "操纵杆",
616
+ "614": "和服",
617
+ "615": "护膝",
618
+ "616": "蝴蝶结",
619
+ "617": "大褂,实验室外套",
620
+ "618": "长柄勺",
621
+ "619": "灯罩",
622
+ "620": "笔记本电脑",
623
+ "621": "割草机",
624
+ "622": "镜头盖",
625
+ "623": "开信刀,裁纸刀",
626
+ "624": "图书馆",
627
+ "625": "救生艇",
628
+ "626": "点火器,打火机",
629
+ "627": "豪华轿车",
630
+ "628": "远洋班轮",
631
+ "629": "唇膏,口红",
632
+ "630": "平底便鞋",
633
+ "631": "洗剂",
634
+ "632": "扬声器",
635
+ "633": "放大镜",
636
+ "634": "锯木厂",
637
+ "635": "磁罗盘",
638
+ "636": "邮袋",
639
+ "637": "信箱",
640
+ "638": "女游泳衣",
641
+ "639": "有肩带浴衣",
642
+ "640": "窨井盖",
643
+ "641": "沙球(一种打击乐器)",
644
+ "642": "马林巴木琴",
645
+ "643": "面膜",
646
+ "644": "火柴",
647
+ "645": "花柱",
648
+ "646": "迷宫",
649
+ "647": "量杯",
650
+ "648": "药箱",
651
+ "649": "巨石,巨石结构",
652
+ "650": "麦克风",
653
+ "651": "微波炉",
654
+ "652": "军装",
655
+ "653": "奶桶",
656
+ "654": "迷你巴士",
657
+ "655": "迷你裙",
658
+ "656": "面包车",
659
+ "657": "导弹",
660
+ "658": "连指手套",
661
+ "659": "搅拌钵",
662
+ "660": "活动房屋(由汽车拖拉的)",
663
+ "661": "T型发动机小汽车",
664
+ "662": "调制解调器",
665
+ "663": "修道院",
666
+ "664": "显示器",
667
+ "665": "电瓶车",
668
+ "666": "砂浆",
669
+ "667": "学士",
670
+ "668": "清真寺",
671
+ "669": "蚊帐",
672
+ "670": "摩托车",
673
+ "671": "山地自行车",
674
+ "672": "登山帐",
675
+ "673": "鼠标,电脑鼠标",
676
+ "674": "捕鼠器",
677
+ "675": "搬家车",
678
+ "676": "口套",
679
+ "677": "钉子",
680
+ "678": "颈托",
681
+ "679": "项链",
682
+ "680": "乳头(瓶)",
683
+ "681": "笔记本,笔记本电脑",
684
+ "682": "方尖碑",
685
+ "683": "双簧管",
686
+ "684": "陶笛,卵形笛",
687
+ "685": "里程表",
688
+ "686": "滤油器",
689
+ "687": "风琴,管风琴",
690
+ "688": "示波器",
691
+ "689": "罩裙",
692
+ "690": "牛车",
693
+ "691": "氧气面罩",
694
+ "692": "包装",
695
+ "693": "船桨",
696
+ "694": "明轮,桨轮",
697
+ "695": "挂锁,扣锁",
698
+ "696": "画笔",
699
+ "697": "睡衣",
700
+ "698": "宫殿",
701
+ "699": "排箫,鸣管",
702
+ "700": "纸巾",
703
+ "701": "降落伞",
704
+ "702": "双杠",
705
+ "703": "公园长椅",
706
+ "704": "停车收费表,停车计时器",
707
+ "705": "客车,教练车",
708
+ "706": "露台,阳台",
709
+ "707": "付费电话",
710
+ "708": "基座,基脚",
711
+ "709": "铅笔盒",
712
+ "710": "卷笔刀",
713
+ "711": "香水(瓶)",
714
+ "712": "培养皿",
715
+ "713": "复印机",
716
+ "714": "拨弦片,拨子",
717
+ "715": "尖顶头盔",
718
+ "716": "栅栏,栅栏",
719
+ "717": "皮卡,皮卡车",
720
+ "718": "桥墩",
721
+ "719": "存钱罐",
722
+ "720": "药瓶",
723
+ "721": "枕头",
724
+ "722": "乒乓球",
725
+ "723": "风车",
726
+ "724": "海盗船",
727
+ "725": "水罐",
728
+ "726": "木工刨",
729
+ "727": "天文馆",
730
+ "728": "塑料袋",
731
+ "729": "板架",
732
+ "730": "犁型铲雪机",
733
+ "731": "手压皮碗泵",
734
+ "732": "宝丽来相机",
735
+ "733": "电线杆",
736
+ "734": "警车,巡逻车",
737
+ "735": "雨披",
738
+ "736": "台球桌",
739
+ "737": "充气饮料瓶",
740
+ "738": "花盆",
741
+ "739": "陶工旋盘",
742
+ "740": "电钻",
743
+ "741": "祈祷垫,地毯",
744
+ "742": "打印机",
745
+ "743": "监狱",
746
+ "744": "炮弹,导弹",
747
+ "745": "投影仪",
748
+ "746": "冰球",
749
+ "747": "沙包,吊球",
750
+ "748": "钱包",
751
+ "749": "羽管笔",
752
+ "750": "被子",
753
+ "751": "赛车",
754
+ "752": "球拍",
755
+ "753": "散热器",
756
+ "754": "收音机",
757
+ "755": "射电望远镜,无线电反射器",
758
+ "756": "雨桶",
759
+ "757": "休闲车,房车",
760
+ "758": "卷轴,卷筒",
761
+ "759": "反射式照相机",
762
+ "760": "冰箱,冰柜",
763
+ "761": "遥控器",
764
+ "762": "餐厅,饮食店,食堂",
765
+ "763": "左轮手枪",
766
+ "764": "步枪",
767
+ "765": "摇椅",
768
+ "766": "电转烤肉架",
769
+ "767": "橡皮",
770
+ "768": "橄榄球",
771
+ "769": "直尺",
772
+ "770": "跑步鞋",
773
+ "771": "保险柜",
774
+ "772": "安全别针",
775
+ "773": "盐瓶(调味用)",
776
+ "774": "凉鞋",
777
+ "775": "纱笼,围裙",
778
+ "776": "萨克斯管",
779
+ "777": "剑鞘",
780
+ "778": "秤,称重机",
781
+ "779": "校车",
782
+ "780": "帆船",
783
+ "781": "记分牌",
784
+ "782": "屏幕",
785
+ "783": "螺丝",
786
+ "784": "螺丝刀",
787
+ "785": "安全带",
788
+ "786": "缝纫机",
789
+ "787": "盾牌,盾牌",
790
+ "788": "皮鞋店,鞋店",
791
+ "789": "障子",
792
+ "790": "购物篮",
793
+ "791": "购物车",
794
+ "792": "铁锹",
795
+ "793": "浴帽",
796
+ "794": "浴帘",
797
+ "795": "滑雪板",
798
+ "796": "滑雪面罩",
799
+ "797": "睡袋",
800
+ "798": "滑尺",
801
+ "799": "滑动门",
802
+ "800": "角子老虎机",
803
+ "801": "潜水通气管",
804
+ "802": "雪橇",
805
+ "803": "扫雪机,扫雪机",
806
+ "804": "皂液器",
807
+ "805": "足球",
808
+ "806": "袜子",
809
+ "807": "碟式太阳能,太阳能集热器,太阳能炉",
810
+ "808": "宽边帽",
811
+ "809": "汤碗",
812
+ "810": "空格键",
813
+ "811": "空间加热器",
814
+ "812": "航天飞机",
815
+ "813": "铲(搅拌或涂敷用的)",
816
+ "814": "快艇",
817
+ "815": "蜘蛛网",
818
+ "816": "纺锤,纱锭",
819
+ "817": "跑车",
820
+ "818": "聚光灯",
821
+ "819": "舞台",
822
+ "820": "蒸汽机车",
823
+ "821": "钢拱桥",
824
+ "822": "钢滚筒",
825
+ "823": "听诊器",
826
+ "824": "女用披肩",
827
+ "825": "石头墙",
828
+ "826": "秒表",
829
+ "827": "火炉",
830
+ "828": "过滤器",
831
+ "829": "有轨电车,电车",
832
+ "830": "担架",
833
+ "831": "沙发床",
834
+ "832": "佛塔",
835
+ "833": "潜艇,潜水艇",
836
+ "834": "套装,衣服",
837
+ "835": "日晷",
838
+ "836": "太阳镜",
839
+ "837": "太阳镜,墨镜",
840
+ "838": "防晒霜,防晒剂",
841
+ "839": "悬索桥",
842
+ "840": "拖把",
843
+ "841": "运动衫",
844
+ "842": "游泳裤",
845
+ "843": "秋千",
846
+ "844": "开关,电器开关",
847
+ "845": "注射器",
848
+ "846": "台灯",
849
+ "847": "坦克,装甲战车,装甲战斗车辆",
850
+ "848": "磁带播放器",
851
+ "849": "茶壶",
852
+ "850": "泰迪,泰迪熊",
853
+ "851": "电视",
854
+ "852": "网球",
855
+ "853": "茅草,茅草屋顶",
856
+ "854": "幕布,剧院的帷幕",
857
+ "855": "顶针",
858
+ "856": "脱粒机",
859
+ "857": "宝座",
860
+ "858": "瓦屋顶",
861
+ "859": "烤面包机",
862
+ "860": "烟草店,烟草",
863
+ "861": "马桶",
864
+ "862": "火炬",
865
+ "863": "图腾柱",
866
+ "864": "拖车,牵引车,清障车",
867
+ "865": "玩具店",
868
+ "866": "拖拉机",
869
+ "867": "拖车,铰接式卡车",
870
+ "868": "托盘",
871
+ "869": "风衣",
872
+ "870": "三轮车",
873
+ "871": "三体船",
874
+ "872": "三脚架",
875
+ "873": "凯旋门",
876
+ "874": "无轨电车",
877
+ "875": "长号",
878
+ "876": "浴盆,浴缸",
879
+ "877": "旋转式栅门",
880
+ "878": "打字机键盘",
881
+ "879": "伞",
882
+ "880": "独轮车",
883
+ "881": "直立式钢琴",
884
+ "882": "真空吸尘器",
885
+ "883": "花瓶",
886
+ "884": "拱顶",
887
+ "885": "天鹅绒",
888
+ "886": "自动售货机",
889
+ "887": "祭服",
890
+ "888": "高架桥",
891
+ "889": "小提琴,小提琴",
892
+ "890": "排球",
893
+ "891": "松饼机",
894
+ "892": "挂钟",
895
+ "893": "钱包,皮夹",
896
+ "894": "衣柜,壁橱",
897
+ "895": "军用飞机",
898
+ "896": "洗脸盆,洗手盆",
899
+ "897": "洗衣机,自动洗衣机",
900
+ "898": "水瓶",
901
+ "899": "水壶",
902
+ "900": "水塔",
903
+ "901": "威士忌壶",
904
+ "902": "哨子",
905
+ "903": "假发",
906
+ "904": "纱窗",
907
+ "905": "百叶窗",
908
+ "906": "温莎领带",
909
+ "907": "葡萄酒瓶",
910
+ "908": "飞机翅膀,飞机",
911
+ "909": "炒菜锅",
912
+ "910": "木制的勺子",
913
+ "911": "毛织品,羊绒",
914
+ "912": "栅栏,围栏",
915
+ "913": "沉船",
916
+ "914": "双桅船",
917
+ "915": "蒙古包",
918
+ "916": "网站,互联网网站",
919
+ "917": "漫画",
920
+ "918": "纵横字谜",
921
+ "919": "路标",
922
+ "920": "交通信号灯",
923
+ "921": "防尘罩,书皮",
924
+ "922": "菜单",
925
+ "923": "盘子",
926
+ "924": "鳄梨酱",
927
+ "925": "清汤",
928
+ "926": "罐焖土豆烧肉",
929
+ "927": "蛋糕",
930
+ "928": "冰淇淋",
931
+ "929": "雪糕,冰棍,冰棒",
932
+ "930": "法式面包",
933
+ "931": "百吉饼",
934
+ "932": "椒盐脆饼",
935
+ "933": "芝士汉堡",
936
+ "934": "热狗",
937
+ "935": "土豆泥",
938
+ "936": "结球甘蓝",
939
+ "937": "西兰花",
940
+ "938": "菜花",
941
+ "939": "绿皮密生西葫芦",
942
+ "940": "西葫芦",
943
+ "941": "小青南瓜",
944
+ "942": "南瓜",
945
+ "943": "黄瓜",
946
+ "944": "朝鲜蓟",
947
+ "945": "甜椒",
948
+ "946": "刺棘蓟",
949
+ "947": "蘑菇",
950
+ "948": "绿苹果",
951
+ "949": "草莓",
952
+ "950": "橘子",
953
+ "951": "柠檬",
954
+ "952": "无花果",
955
+ "953": "菠萝",
956
+ "954": "香蕉",
957
+ "955": "菠萝蜜",
958
+ "956": "蛋奶冻苹果",
959
+ "957": "石榴",
960
+ "958": "干草",
961
+ "959": "烤面条加干酪沙司",
962
+ "960": "巧克力酱,巧克力糖浆",
963
+ "961": "面团",
964
+ "962": "瑞士肉包,肉饼",
965
+ "963": "披萨,披萨饼",
966
+ "964": "馅饼",
967
+ "965": "卷饼",
968
+ "966": "红葡萄酒",
969
+ "967": "意大利浓咖啡",
970
+ "968": "杯子",
971
+ "969": "蛋酒",
972
+ "970": "高山",
973
+ "971": "泡泡",
974
+ "972": "悬崖",
975
+ "973": "珊瑚礁",
976
+ "974": "间歇泉",
977
+ "975": "湖边,湖岸",
978
+ "976": "海角",
979
+ "977": "沙洲,沙坝",
980
+ "978": "海滨,海岸",
981
+ "979": "峡谷",
982
+ "980": "火山",
983
+ "981": "棒球,棒球运动员",
984
+ "982": "新郎",
985
+ "983": "潜水员",
986
+ "984": "油菜",
987
+ "985": "雏菊",
988
+ "986": "杓兰",
989
+ "987": "玉米",
990
+ "988": "橡子",
991
+ "989": "玫瑰果",
992
+ "990": "七叶树果实",
993
+ "991": "珊瑚菌",
994
+ "992": "木耳",
995
+ "993": "鹿花菌",
996
+ "994": "鬼笔菌",
997
+ "995": "地星(菌类)",
998
+ "996": "多叶奇果菌",
999
+ "997": "牛肝菌",
1000
+ "998": "玉米穗",
1001
+ "999": "卫生纸"
1002
+ }
labels/id2label_en.json ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "tench, Tinca tinca",
3
+ "1": "goldfish, Carassius auratus",
4
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
5
+ "3": "tiger shark, Galeocerdo cuvieri",
6
+ "4": "hammerhead, hammerhead shark",
7
+ "5": "electric ray, crampfish, numbfish, torpedo",
8
+ "6": "stingray",
9
+ "7": "cock",
10
+ "8": "hen",
11
+ "9": "ostrich, Struthio camelus",
12
+ "10": "brambling, Fringilla montifringilla",
13
+ "11": "goldfinch, Carduelis carduelis",
14
+ "12": "house finch, linnet, Carpodacus mexicanus",
15
+ "13": "junco, snowbird",
16
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
17
+ "15": "robin, American robin, Turdus migratorius",
18
+ "16": "bulbul",
19
+ "17": "jay",
20
+ "18": "magpie",
21
+ "19": "chickadee",
22
+ "20": "water ouzel, dipper",
23
+ "21": "kite",
24
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
25
+ "23": "vulture",
26
+ "24": "great grey owl, great gray owl, Strix nebulosa",
27
+ "25": "European fire salamander, Salamandra salamandra",
28
+ "26": "common newt, Triturus vulgaris",
29
+ "27": "eft",
30
+ "28": "spotted salamander, Ambystoma maculatum",
31
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
32
+ "30": "bullfrog, Rana catesbeiana",
33
+ "31": "tree frog, tree-frog",
34
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
35
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
36
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
37
+ "35": "mud turtle",
38
+ "36": "terrapin",
39
+ "37": "box turtle, box tortoise",
40
+ "38": "banded gecko",
41
+ "39": "common iguana, iguana, Iguana iguana",
42
+ "40": "American chameleon, anole, Anolis carolinensis",
43
+ "41": "whiptail, whiptail lizard",
44
+ "42": "agama",
45
+ "43": "frilled lizard, Chlamydosaurus kingi",
46
+ "44": "alligator lizard",
47
+ "45": "Gila monster, Heloderma suspectum",
48
+ "46": "green lizard, Lacerta viridis",
49
+ "47": "African chameleon, Chamaeleo chamaeleon",
50
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
51
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
52
+ "50": "American alligator, Alligator mississipiensis",
53
+ "51": "triceratops",
54
+ "52": "thunder snake, worm snake, Carphophis amoenus",
55
+ "53": "ringneck snake, ring-necked snake, ring snake",
56
+ "54": "hognose snake, puff adder, sand viper",
57
+ "55": "green snake, grass snake",
58
+ "56": "king snake, kingsnake",
59
+ "57": "garter snake, grass snake",
60
+ "58": "water snake",
61
+ "59": "vine snake",
62
+ "60": "night snake, Hypsiglena torquata",
63
+ "61": "boa constrictor, Constrictor constrictor",
64
+ "62": "rock python, rock snake, Python sebae",
65
+ "63": "Indian cobra, Naja naja",
66
+ "64": "green mamba",
67
+ "65": "sea snake",
68
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
69
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
70
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
71
+ "69": "trilobite",
72
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
73
+ "71": "scorpion",
74
+ "72": "black and gold garden spider, Argiope aurantia",
75
+ "73": "barn spider, Araneus cavaticus",
76
+ "74": "garden spider, Aranea diademata",
77
+ "75": "black widow, Latrodectus mactans",
78
+ "76": "tarantula",
79
+ "77": "wolf spider, hunting spider",
80
+ "78": "tick",
81
+ "79": "centipede",
82
+ "80": "black grouse",
83
+ "81": "ptarmigan",
84
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
85
+ "83": "prairie chicken, prairie grouse, prairie fowl",
86
+ "84": "peacock",
87
+ "85": "quail",
88
+ "86": "partridge",
89
+ "87": "African grey, African gray, Psittacus erithacus",
90
+ "88": "macaw",
91
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
92
+ "90": "lorikeet",
93
+ "91": "coucal",
94
+ "92": "bee eater",
95
+ "93": "hornbill",
96
+ "94": "hummingbird",
97
+ "95": "jacamar",
98
+ "96": "toucan",
99
+ "97": "drake",
100
+ "98": "red-breasted merganser, Mergus serrator",
101
+ "99": "goose",
102
+ "100": "black swan, Cygnus atratus",
103
+ "101": "tusker",
104
+ "102": "echidna, spiny anteater, anteater",
105
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
106
+ "104": "wallaby, brush kangaroo",
107
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
108
+ "106": "wombat",
109
+ "107": "jellyfish",
110
+ "108": "sea anemone, anemone",
111
+ "109": "brain coral",
112
+ "110": "flatworm, platyhelminth",
113
+ "111": "nematode, nematode worm, roundworm",
114
+ "112": "conch",
115
+ "113": "snail",
116
+ "114": "slug",
117
+ "115": "sea slug, nudibranch",
118
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
119
+ "117": "chambered nautilus, pearly nautilus, nautilus",
120
+ "118": "Dungeness crab, Cancer magister",
121
+ "119": "rock crab, Cancer irroratus",
122
+ "120": "fiddler crab",
123
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
124
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
125
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
126
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
127
+ "125": "hermit crab",
128
+ "126": "isopod",
129
+ "127": "white stork, Ciconia ciconia",
130
+ "128": "black stork, Ciconia nigra",
131
+ "129": "spoonbill",
132
+ "130": "flamingo",
133
+ "131": "little blue heron, Egretta caerulea",
134
+ "132": "American egret, great white heron, Egretta albus",
135
+ "133": "bittern",
136
+ "134": "crane",
137
+ "135": "limpkin, Aramus pictus",
138
+ "136": "European gallinule, Porphyrio porphyrio",
139
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
140
+ "138": "bustard",
141
+ "139": "ruddy turnstone, Arenaria interpres",
142
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
143
+ "141": "redshank, Tringa totanus",
144
+ "142": "dowitcher",
145
+ "143": "oystercatcher, oyster catcher",
146
+ "144": "pelican",
147
+ "145": "king penguin, Aptenodytes patagonica",
148
+ "146": "albatross, mollymawk",
149
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
150
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
151
+ "149": "dugong, Dugong dugon",
152
+ "150": "sea lion",
153
+ "151": "Chihuahua",
154
+ "152": "Japanese spaniel",
155
+ "153": "Maltese dog, Maltese terrier, Maltese",
156
+ "154": "Pekinese, Pekingese, Peke",
157
+ "155": "Shih-Tzu",
158
+ "156": "Blenheim spaniel",
159
+ "157": "papillon",
160
+ "158": "toy terrier",
161
+ "159": "Rhodesian ridgeback",
162
+ "160": "Afghan hound, Afghan",
163
+ "161": "basset, basset hound",
164
+ "162": "beagle",
165
+ "163": "bloodhound, sleuthhound",
166
+ "164": "bluetick",
167
+ "165": "black-and-tan coonhound",
168
+ "166": "Walker hound, Walker foxhound",
169
+ "167": "English foxhound",
170
+ "168": "redbone",
171
+ "169": "borzoi, Russian wolfhound",
172
+ "170": "Irish wolfhound",
173
+ "171": "Italian greyhound",
174
+ "172": "whippet",
175
+ "173": "Ibizan hound, Ibizan Podenco",
176
+ "174": "Norwegian elkhound, elkhound",
177
+ "175": "otterhound, otter hound",
178
+ "176": "Saluki, gazelle hound",
179
+ "177": "Scottish deerhound, deerhound",
180
+ "178": "Weimaraner",
181
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
182
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
183
+ "181": "Bedlington terrier",
184
+ "182": "Border terrier",
185
+ "183": "Kerry blue terrier",
186
+ "184": "Irish terrier",
187
+ "185": "Norfolk terrier",
188
+ "186": "Norwich terrier",
189
+ "187": "Yorkshire terrier",
190
+ "188": "wire-haired fox terrier",
191
+ "189": "Lakeland terrier",
192
+ "190": "Sealyham terrier, Sealyham",
193
+ "191": "Airedale, Airedale terrier",
194
+ "192": "cairn, cairn terrier",
195
+ "193": "Australian terrier",
196
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
197
+ "195": "Boston bull, Boston terrier",
198
+ "196": "miniature schnauzer",
199
+ "197": "giant schnauzer",
200
+ "198": "standard schnauzer",
201
+ "199": "Scotch terrier, Scottish terrier, Scottie",
202
+ "200": "Tibetan terrier, chrysanthemum dog",
203
+ "201": "silky terrier, Sydney silky",
204
+ "202": "soft-coated wheaten terrier",
205
+ "203": "West Highland white terrier",
206
+ "204": "Lhasa, Lhasa apso",
207
+ "205": "flat-coated retriever",
208
+ "206": "curly-coated retriever",
209
+ "207": "golden retriever",
210
+ "208": "Labrador retriever",
211
+ "209": "Chesapeake Bay retriever",
212
+ "210": "German short-haired pointer",
213
+ "211": "vizsla, Hungarian pointer",
214
+ "212": "English setter",
215
+ "213": "Irish setter, red setter",
216
+ "214": "Gordon setter",
217
+ "215": "Brittany spaniel",
218
+ "216": "clumber, clumber spaniel",
219
+ "217": "English springer, English springer spaniel",
220
+ "218": "Welsh springer spaniel",
221
+ "219": "cocker spaniel, English cocker spaniel, cocker",
222
+ "220": "Sussex spaniel",
223
+ "221": "Irish water spaniel",
224
+ "222": "kuvasz",
225
+ "223": "schipperke",
226
+ "224": "groenendael",
227
+ "225": "malinois",
228
+ "226": "briard",
229
+ "227": "kelpie",
230
+ "228": "komondor",
231
+ "229": "Old English sheepdog, bobtail",
232
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
233
+ "231": "collie",
234
+ "232": "Border collie",
235
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
236
+ "234": "Rottweiler",
237
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
238
+ "236": "Doberman, Doberman pinscher",
239
+ "237": "miniature pinscher",
240
+ "238": "Greater Swiss Mountain dog",
241
+ "239": "Bernese mountain dog",
242
+ "240": "Appenzeller",
243
+ "241": "EntleBucher",
244
+ "242": "boxer",
245
+ "243": "bull mastiff",
246
+ "244": "Tibetan mastiff",
247
+ "245": "French bulldog",
248
+ "246": "Great Dane",
249
+ "247": "Saint Bernard, St Bernard",
250
+ "248": "Eskimo dog, husky",
251
+ "249": "malamute, malemute, Alaskan malamute",
252
+ "250": "Siberian husky",
253
+ "251": "dalmatian, coach dog, carriage dog",
254
+ "252": "affenpinscher, monkey pinscher, monkey dog",
255
+ "253": "basenji",
256
+ "254": "pug, pug-dog",
257
+ "255": "Leonberg",
258
+ "256": "Newfoundland, Newfoundland dog",
259
+ "257": "Great Pyrenees",
260
+ "258": "Samoyed, Samoyede",
261
+ "259": "Pomeranian",
262
+ "260": "chow, chow chow",
263
+ "261": "keeshond",
264
+ "262": "Brabancon griffon",
265
+ "263": "Pembroke, Pembroke Welsh corgi",
266
+ "264": "Cardigan, Cardigan Welsh corgi",
267
+ "265": "toy poodle",
268
+ "266": "miniature poodle",
269
+ "267": "standard poodle",
270
+ "268": "Mexican hairless",
271
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
272
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
273
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
274
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
275
+ "273": "dingo, warrigal, warragal, Canis dingo",
276
+ "274": "dhole, Cuon alpinus",
277
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
278
+ "276": "hyena, hyaena",
279
+ "277": "red fox, Vulpes vulpes",
280
+ "278": "kit fox, Vulpes macrotis",
281
+ "279": "Arctic fox, white fox, Alopex lagopus",
282
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
283
+ "281": "tabby, tabby cat",
284
+ "282": "tiger cat",
285
+ "283": "Persian cat",
286
+ "284": "Siamese cat, Siamese",
287
+ "285": "Egyptian cat",
288
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
289
+ "287": "lynx, catamount",
290
+ "288": "leopard, Panthera pardus",
291
+ "289": "snow leopard, ounce, Panthera uncia",
292
+ "290": "jaguar, panther, Panthera onca, Felis onca",
293
+ "291": "lion, king of beasts, Panthera leo",
294
+ "292": "tiger, Panthera tigris",
295
+ "293": "cheetah, chetah, Acinonyx jubatus",
296
+ "294": "brown bear, bruin, Ursus arctos",
297
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
298
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
299
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
300
+ "298": "mongoose",
301
+ "299": "meerkat, mierkat",
302
+ "300": "tiger beetle",
303
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
304
+ "302": "ground beetle, carabid beetle",
305
+ "303": "long-horned beetle, longicorn, longicorn beetle",
306
+ "304": "leaf beetle, chrysomelid",
307
+ "305": "dung beetle",
308
+ "306": "rhinoceros beetle",
309
+ "307": "weevil",
310
+ "308": "fly",
311
+ "309": "bee",
312
+ "310": "ant, emmet, pismire",
313
+ "311": "grasshopper, hopper",
314
+ "312": "cricket",
315
+ "313": "walking stick, walkingstick, stick insect",
316
+ "314": "cockroach, roach",
317
+ "315": "mantis, mantid",
318
+ "316": "cicada, cicala",
319
+ "317": "leafhopper",
320
+ "318": "lacewing, lacewing fly",
321
+ "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ "320": "damselfly",
323
+ "321": "admiral",
324
+ "322": "ringlet, ringlet butterfly",
325
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
326
+ "324": "cabbage butterfly",
327
+ "325": "sulphur butterfly, sulfur butterfly",
328
+ "326": "lycaenid, lycaenid butterfly",
329
+ "327": "starfish, sea star",
330
+ "328": "sea urchin",
331
+ "329": "sea cucumber, holothurian",
332
+ "330": "wood rabbit, cottontail, cottontail rabbit",
333
+ "331": "hare",
334
+ "332": "Angora, Angora rabbit",
335
+ "333": "hamster",
336
+ "334": "porcupine, hedgehog",
337
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
338
+ "336": "marmot",
339
+ "337": "beaver",
340
+ "338": "guinea pig, Cavia cobaya",
341
+ "339": "sorrel",
342
+ "340": "zebra",
343
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
344
+ "342": "wild boar, boar, Sus scrofa",
345
+ "343": "warthog",
346
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
347
+ "345": "ox",
348
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
349
+ "347": "bison",
350
+ "348": "ram, tup",
351
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
352
+ "350": "ibex, Capra ibex",
353
+ "351": "hartebeest",
354
+ "352": "impala, Aepyceros melampus",
355
+ "353": "gazelle",
356
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
357
+ "355": "llama",
358
+ "356": "weasel",
359
+ "357": "mink",
360
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
361
+ "359": "black-footed ferret, ferret, Mustela nigripes",
362
+ "360": "otter",
363
+ "361": "skunk, polecat, wood pussy",
364
+ "362": "badger",
365
+ "363": "armadillo",
366
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
367
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
368
+ "366": "gorilla, Gorilla gorilla",
369
+ "367": "chimpanzee, chimp, Pan troglodytes",
370
+ "368": "gibbon, Hylobates lar",
371
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
372
+ "370": "guenon, guenon monkey",
373
+ "371": "patas, hussar monkey, Erythrocebus patas",
374
+ "372": "baboon",
375
+ "373": "macaque",
376
+ "374": "langur",
377
+ "375": "colobus, colobus monkey",
378
+ "376": "proboscis monkey, Nasalis larvatus",
379
+ "377": "marmoset",
380
+ "378": "capuchin, ringtail, Cebus capucinus",
381
+ "379": "howler monkey, howler",
382
+ "380": "titi, titi monkey",
383
+ "381": "spider monkey, Ateles geoffroyi",
384
+ "382": "squirrel monkey, Saimiri sciureus",
385
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
386
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
387
+ "385": "Indian elephant, Elephas maximus",
388
+ "386": "African elephant, Loxodonta africana",
389
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
390
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
391
+ "389": "barracouta, snoek",
392
+ "390": "eel",
393
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
394
+ "392": "rock beauty, Holocanthus tricolor",
395
+ "393": "anemone fish",
396
+ "394": "sturgeon",
397
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
398
+ "396": "lionfish",
399
+ "397": "puffer, pufferfish, blowfish, globefish",
400
+ "398": "abacus",
401
+ "399": "abaya",
402
+ "400": "academic gown, academic robe, judge robe",
403
+ "401": "accordion, piano accordion, squeeze box",
404
+ "402": "acoustic guitar",
405
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
406
+ "404": "airliner",
407
+ "405": "airship, dirigible",
408
+ "406": "altar",
409
+ "407": "ambulance",
410
+ "408": "amphibian, amphibious vehicle",
411
+ "409": "analog clock",
412
+ "410": "apiary, bee house",
413
+ "411": "apron",
414
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
415
+ "413": "assault rifle, assault gun",
416
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
417
+ "415": "bakery, bakeshop, bakehouse",
418
+ "416": "balance beam, beam",
419
+ "417": "balloon",
420
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
421
+ "419": "Band Aid",
422
+ "420": "banjo",
423
+ "421": "bannister, banister, balustrade, balusters, handrail",
424
+ "422": "barbell",
425
+ "423": "barber chair",
426
+ "424": "barbershop",
427
+ "425": "barn",
428
+ "426": "barometer",
429
+ "427": "barrel, cask",
430
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
431
+ "429": "baseball",
432
+ "430": "basketball",
433
+ "431": "bassinet",
434
+ "432": "bassoon",
435
+ "433": "bathing cap, swimming cap",
436
+ "434": "bath towel",
437
+ "435": "bathtub, bathing tub, bath, tub",
438
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
439
+ "437": "beacon, lighthouse, beacon light, pharos",
440
+ "438": "beaker",
441
+ "439": "bearskin, busby, shako",
442
+ "440": "beer bottle",
443
+ "441": "beer glass",
444
+ "442": "bell cote, bell cot",
445
+ "443": "bib",
446
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
447
+ "445": "bikini, two-piece",
448
+ "446": "binder, ring-binder",
449
+ "447": "binoculars, field glasses, opera glasses",
450
+ "448": "birdhouse",
451
+ "449": "boathouse",
452
+ "450": "bobsled, bobsleigh, bob",
453
+ "451": "bolo tie, bolo, bola tie, bola",
454
+ "452": "bonnet, poke bonnet",
455
+ "453": "bookcase",
456
+ "454": "bookshop, bookstore, bookstall",
457
+ "455": "bottlecap",
458
+ "456": "bow",
459
+ "457": "bow tie, bow-tie, bowtie",
460
+ "458": "brass, memorial tablet, plaque",
461
+ "459": "brassiere, bra, bandeau",
462
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
463
+ "461": "breastplate, aegis, egis",
464
+ "462": "broom",
465
+ "463": "bucket, pail",
466
+ "464": "buckle",
467
+ "465": "bulletproof vest",
468
+ "466": "bullet train, bullet",
469
+ "467": "butcher shop, meat market",
470
+ "468": "cab, hack, taxi, taxicab",
471
+ "469": "caldron, cauldron",
472
+ "470": "candle, taper, wax light",
473
+ "471": "cannon",
474
+ "472": "canoe",
475
+ "473": "can opener, tin opener",
476
+ "474": "cardigan",
477
+ "475": "car mirror",
478
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
479
+ "477": "carpenters kit, tool kit",
480
+ "478": "carton",
481
+ "479": "car wheel",
482
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
483
+ "481": "cassette",
484
+ "482": "cassette player",
485
+ "483": "castle",
486
+ "484": "catamaran",
487
+ "485": "CD player",
488
+ "486": "cello, violoncello",
489
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
490
+ "488": "chain",
491
+ "489": "chainlink fence",
492
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
493
+ "491": "chain saw, chainsaw",
494
+ "492": "chest",
495
+ "493": "chiffonier, commode",
496
+ "494": "chime, bell, gong",
497
+ "495": "china cabinet, china closet",
498
+ "496": "Christmas stocking",
499
+ "497": "church, church building",
500
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
501
+ "499": "cleaver, meat cleaver, chopper",
502
+ "500": "cliff dwelling",
503
+ "501": "cloak",
504
+ "502": "clog, geta, patten, sabot",
505
+ "503": "cocktail shaker",
506
+ "504": "coffee mug",
507
+ "505": "coffeepot",
508
+ "506": "coil, spiral, volute, whorl, helix",
509
+ "507": "combination lock",
510
+ "508": "computer keyboard, keypad",
511
+ "509": "confectionery, confectionary, candy store",
512
+ "510": "container ship, containership, container vessel",
513
+ "511": "convertible",
514
+ "512": "corkscrew, bottle screw",
515
+ "513": "cornet, horn, trumpet, trump",
516
+ "514": "cowboy boot",
517
+ "515": "cowboy hat, ten-gallon hat",
518
+ "516": "cradle",
519
+ "517": "crane",
520
+ "518": "crash helmet",
521
+ "519": "crate",
522
+ "520": "crib, cot",
523
+ "521": "Crock Pot",
524
+ "522": "croquet ball",
525
+ "523": "crutch",
526
+ "524": "cuirass",
527
+ "525": "dam, dike, dyke",
528
+ "526": "desk",
529
+ "527": "desktop computer",
530
+ "528": "dial telephone, dial phone",
531
+ "529": "diaper, nappy, napkin",
532
+ "530": "digital clock",
533
+ "531": "digital watch",
534
+ "532": "dining table, board",
535
+ "533": "dishrag, dishcloth",
536
+ "534": "dishwasher, dish washer, dishwashing machine",
537
+ "535": "disk brake, disc brake",
538
+ "536": "dock, dockage, docking facility",
539
+ "537": "dogsled, dog sled, dog sleigh",
540
+ "538": "dome",
541
+ "539": "doormat, welcome mat",
542
+ "540": "drilling platform, offshore rig",
543
+ "541": "drum, membranophone, tympan",
544
+ "542": "drumstick",
545
+ "543": "dumbbell",
546
+ "544": "Dutch oven",
547
+ "545": "electric fan, blower",
548
+ "546": "electric guitar",
549
+ "547": "electric locomotive",
550
+ "548": "entertainment center",
551
+ "549": "envelope",
552
+ "550": "espresso maker",
553
+ "551": "face powder",
554
+ "552": "feather boa, boa",
555
+ "553": "file, file cabinet, filing cabinet",
556
+ "554": "fireboat",
557
+ "555": "fire engine, fire truck",
558
+ "556": "fire screen, fireguard",
559
+ "557": "flagpole, flagstaff",
560
+ "558": "flute, transverse flute",
561
+ "559": "folding chair",
562
+ "560": "football helmet",
563
+ "561": "forklift",
564
+ "562": "fountain",
565
+ "563": "fountain pen",
566
+ "564": "four-poster",
567
+ "565": "freight car",
568
+ "566": "French horn, horn",
569
+ "567": "frying pan, frypan, skillet",
570
+ "568": "fur coat",
571
+ "569": "garbage truck, dustcart",
572
+ "570": "gasmask, respirator, gas helmet",
573
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
574
+ "572": "goblet",
575
+ "573": "go-kart",
576
+ "574": "golf ball",
577
+ "575": "golfcart, golf cart",
578
+ "576": "gondola",
579
+ "577": "gong, tam-tam",
580
+ "578": "gown",
581
+ "579": "grand piano, grand",
582
+ "580": "greenhouse, nursery, glasshouse",
583
+ "581": "grille, radiator grille",
584
+ "582": "grocery store, grocery, food market, market",
585
+ "583": "guillotine",
586
+ "584": "hair slide",
587
+ "585": "hair spray",
588
+ "586": "half track",
589
+ "587": "hammer",
590
+ "588": "hamper",
591
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
592
+ "590": "hand-held computer, hand-held microcomputer",
593
+ "591": "handkerchief, hankie, hanky, hankey",
594
+ "592": "hard disc, hard disk, fixed disk",
595
+ "593": "harmonica, mouth organ, harp, mouth harp",
596
+ "594": "harp",
597
+ "595": "harvester, reaper",
598
+ "596": "hatchet",
599
+ "597": "holster",
600
+ "598": "home theater, home theatre",
601
+ "599": "honeycomb",
602
+ "600": "hook, claw",
603
+ "601": "hoopskirt, crinoline",
604
+ "602": "horizontal bar, high bar",
605
+ "603": "horse cart, horse-cart",
606
+ "604": "hourglass",
607
+ "605": "iPod",
608
+ "606": "iron, smoothing iron",
609
+ "607": "jack-o-lantern",
610
+ "608": "jean, blue jean, denim",
611
+ "609": "jeep, landrover",
612
+ "610": "jersey, T-shirt, tee shirt",
613
+ "611": "jigsaw puzzle",
614
+ "612": "jinrikisha, ricksha, rickshaw",
615
+ "613": "joystick",
616
+ "614": "kimono",
617
+ "615": "knee pad",
618
+ "616": "knot",
619
+ "617": "lab coat, laboratory coat",
620
+ "618": "ladle",
621
+ "619": "lampshade, lamp shade",
622
+ "620": "laptop, laptop computer",
623
+ "621": "lawn mower, mower",
624
+ "622": "lens cap, lens cover",
625
+ "623": "letter opener, paper knife, paperknife",
626
+ "624": "library",
627
+ "625": "lifeboat",
628
+ "626": "lighter, light, igniter, ignitor",
629
+ "627": "limousine, limo",
630
+ "628": "liner, ocean liner",
631
+ "629": "lipstick, lip rouge",
632
+ "630": "Loafer",
633
+ "631": "lotion",
634
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
635
+ "633": "loupe, jewelers loupe",
636
+ "634": "lumbermill, sawmill",
637
+ "635": "magnetic compass",
638
+ "636": "mailbag, postbag",
639
+ "637": "mailbox, letter box",
640
+ "638": "maillot",
641
+ "639": "maillot, tank suit",
642
+ "640": "manhole cover",
643
+ "641": "maraca",
644
+ "642": "marimba, xylophone",
645
+ "643": "mask",
646
+ "644": "matchstick",
647
+ "645": "maypole",
648
+ "646": "maze, labyrinth",
649
+ "647": "measuring cup",
650
+ "648": "medicine chest, medicine cabinet",
651
+ "649": "megalith, megalithic structure",
652
+ "650": "microphone, mike",
653
+ "651": "microwave, microwave oven",
654
+ "652": "military uniform",
655
+ "653": "milk can",
656
+ "654": "minibus",
657
+ "655": "miniskirt, mini",
658
+ "656": "minivan",
659
+ "657": "missile",
660
+ "658": "mitten",
661
+ "659": "mixing bowl",
662
+ "660": "mobile home, manufactured home",
663
+ "661": "Model T",
664
+ "662": "modem",
665
+ "663": "monastery",
666
+ "664": "monitor",
667
+ "665": "moped",
668
+ "666": "mortar",
669
+ "667": "mortarboard",
670
+ "668": "mosque",
671
+ "669": "mosquito net",
672
+ "670": "motor scooter, scooter",
673
+ "671": "mountain bike, all-terrain bike, off-roader",
674
+ "672": "mountain tent",
675
+ "673": "mouse, computer mouse",
676
+ "674": "mousetrap",
677
+ "675": "moving van",
678
+ "676": "muzzle",
679
+ "677": "nail",
680
+ "678": "neck brace",
681
+ "679": "necklace",
682
+ "680": "nipple",
683
+ "681": "notebook, notebook computer",
684
+ "682": "obelisk",
685
+ "683": "oboe, hautboy, hautbois",
686
+ "684": "ocarina, sweet potato",
687
+ "685": "odometer, hodometer, mileometer, milometer",
688
+ "686": "oil filter",
689
+ "687": "organ, pipe organ",
690
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
691
+ "689": "overskirt",
692
+ "690": "oxcart",
693
+ "691": "oxygen mask",
694
+ "692": "packet",
695
+ "693": "paddle, boat paddle",
696
+ "694": "paddlewheel, paddle wheel",
697
+ "695": "padlock",
698
+ "696": "paintbrush",
699
+ "697": "pajama, pyjama, pjs, jammies",
700
+ "698": "palace",
701
+ "699": "panpipe, pandean pipe, syrinx",
702
+ "700": "paper towel",
703
+ "701": "parachute, chute",
704
+ "702": "parallel bars, bars",
705
+ "703": "park bench",
706
+ "704": "parking meter",
707
+ "705": "passenger car, coach, carriage",
708
+ "706": "patio, terrace",
709
+ "707": "pay-phone, pay-station",
710
+ "708": "pedestal, plinth, footstall",
711
+ "709": "pencil box, pencil case",
712
+ "710": "pencil sharpener",
713
+ "711": "perfume, essence",
714
+ "712": "Petri dish",
715
+ "713": "photocopier",
716
+ "714": "pick, plectrum, plectron",
717
+ "715": "pickelhaube",
718
+ "716": "picket fence, paling",
719
+ "717": "pickup, pickup truck",
720
+ "718": "pier",
721
+ "719": "piggy bank, penny bank",
722
+ "720": "pill bottle",
723
+ "721": "pillow",
724
+ "722": "ping-pong ball",
725
+ "723": "pinwheel",
726
+ "724": "pirate, pirate ship",
727
+ "725": "pitcher, ewer",
728
+ "726": "plane, carpenters plane, woodworking plane",
729
+ "727": "planetarium",
730
+ "728": "plastic bag",
731
+ "729": "plate rack",
732
+ "730": "plow, plough",
733
+ "731": "plunger, plumbers helper",
734
+ "732": "Polaroid camera, Polaroid Land camera",
735
+ "733": "pole",
736
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
737
+ "735": "poncho",
738
+ "736": "pool table, billiard table, snooker table",
739
+ "737": "pop bottle, soda bottle",
740
+ "738": "pot, flowerpot",
741
+ "739": "potters wheel",
742
+ "740": "power drill",
743
+ "741": "prayer rug, prayer mat",
744
+ "742": "printer",
745
+ "743": "prison, prison house",
746
+ "744": "projectile, missile",
747
+ "745": "projector",
748
+ "746": "puck, hockey puck",
749
+ "747": "punching bag, punch bag, punching ball, punchball",
750
+ "748": "purse",
751
+ "749": "quill, quill pen",
752
+ "750": "quilt, comforter, comfort, puff",
753
+ "751": "racer, race car, racing car",
754
+ "752": "racket, racquet",
755
+ "753": "radiator",
756
+ "754": "radio, wireless",
757
+ "755": "radio telescope, radio reflector",
758
+ "756": "rain barrel",
759
+ "757": "recreational vehicle, RV, R.V.",
760
+ "758": "reel",
761
+ "759": "reflex camera",
762
+ "760": "refrigerator, icebox",
763
+ "761": "remote control, remote",
764
+ "762": "restaurant, eating house, eating place, eatery",
765
+ "763": "revolver, six-gun, six-shooter",
766
+ "764": "rifle",
767
+ "765": "rocking chair, rocker",
768
+ "766": "rotisserie",
769
+ "767": "rubber eraser, rubber, pencil eraser",
770
+ "768": "rugby ball",
771
+ "769": "rule, ruler",
772
+ "770": "running shoe",
773
+ "771": "safe",
774
+ "772": "safety pin",
775
+ "773": "saltshaker, salt shaker",
776
+ "774": "sandal",
777
+ "775": "sarong",
778
+ "776": "sax, saxophone",
779
+ "777": "scabbard",
780
+ "778": "scale, weighing machine",
781
+ "779": "school bus",
782
+ "780": "schooner",
783
+ "781": "scoreboard",
784
+ "782": "screen, CRT screen",
785
+ "783": "screw",
786
+ "784": "screwdriver",
787
+ "785": "seat belt, seatbelt",
788
+ "786": "sewing machine",
789
+ "787": "shield, buckler",
790
+ "788": "shoe shop, shoe-shop, shoe store",
791
+ "789": "shoji",
792
+ "790": "shopping basket",
793
+ "791": "shopping cart",
794
+ "792": "shovel",
795
+ "793": "shower cap",
796
+ "794": "shower curtain",
797
+ "795": "ski",
798
+ "796": "ski mask",
799
+ "797": "sleeping bag",
800
+ "798": "slide rule, slipstick",
801
+ "799": "sliding door",
802
+ "800": "slot, one-armed bandit",
803
+ "801": "snorkel",
804
+ "802": "snowmobile",
805
+ "803": "snowplow, snowplough",
806
+ "804": "soap dispenser",
807
+ "805": "soccer ball",
808
+ "806": "sock",
809
+ "807": "solar dish, solar collector, solar furnace",
810
+ "808": "sombrero",
811
+ "809": "soup bowl",
812
+ "810": "space bar",
813
+ "811": "space heater",
814
+ "812": "space shuttle",
815
+ "813": "spatula",
816
+ "814": "speedboat",
817
+ "815": "spider web, spiders web",
818
+ "816": "spindle",
819
+ "817": "sports car, sport car",
820
+ "818": "spotlight, spot",
821
+ "819": "stage",
822
+ "820": "steam locomotive",
823
+ "821": "steel arch bridge",
824
+ "822": "steel drum",
825
+ "823": "stethoscope",
826
+ "824": "stole",
827
+ "825": "stone wall",
828
+ "826": "stopwatch, stop watch",
829
+ "827": "stove",
830
+ "828": "strainer",
831
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
832
+ "830": "stretcher",
833
+ "831": "studio couch, day bed",
834
+ "832": "stupa, tope",
835
+ "833": "submarine, pigboat, sub, U-boat",
836
+ "834": "suit, suit of clothes",
837
+ "835": "sundial",
838
+ "836": "sunglass",
839
+ "837": "sunglasses, dark glasses, shades",
840
+ "838": "sunscreen, sunblock, sun blocker",
841
+ "839": "suspension bridge",
842
+ "840": "swab, swob, mop",
843
+ "841": "sweatshirt",
844
+ "842": "swimming trunks, bathing trunks",
845
+ "843": "swing",
846
+ "844": "switch, electric switch, electrical switch",
847
+ "845": "syringe",
848
+ "846": "table lamp",
849
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
850
+ "848": "tape player",
851
+ "849": "teapot",
852
+ "850": "teddy, teddy bear",
853
+ "851": "television, television system",
854
+ "852": "tennis ball",
855
+ "853": "thatch, thatched roof",
856
+ "854": "theater curtain, theatre curtain",
857
+ "855": "thimble",
858
+ "856": "thresher, thrasher, threshing machine",
859
+ "857": "throne",
860
+ "858": "tile roof",
861
+ "859": "toaster",
862
+ "860": "tobacco shop, tobacconist shop, tobacconist",
863
+ "861": "toilet seat",
864
+ "862": "torch",
865
+ "863": "totem pole",
866
+ "864": "tow truck, tow car, wrecker",
867
+ "865": "toyshop",
868
+ "866": "tractor",
869
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
870
+ "868": "tray",
871
+ "869": "trench coat",
872
+ "870": "tricycle, trike, velocipede",
873
+ "871": "trimaran",
874
+ "872": "tripod",
875
+ "873": "triumphal arch",
876
+ "874": "trolleybus, trolley coach, trackless trolley",
877
+ "875": "trombone",
878
+ "876": "tub, vat",
879
+ "877": "turnstile",
880
+ "878": "typewriter keyboard",
881
+ "879": "umbrella",
882
+ "880": "unicycle, monocycle",
883
+ "881": "upright, upright piano",
884
+ "882": "vacuum, vacuum cleaner",
885
+ "883": "vase",
886
+ "884": "vault",
887
+ "885": "velvet",
888
+ "886": "vending machine",
889
+ "887": "vestment",
890
+ "888": "viaduct",
891
+ "889": "violin, fiddle",
892
+ "890": "volleyball",
893
+ "891": "waffle iron",
894
+ "892": "wall clock",
895
+ "893": "wallet, billfold, notecase, pocketbook",
896
+ "894": "wardrobe, closet, press",
897
+ "895": "warplane, military plane",
898
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
899
+ "897": "washer, automatic washer, washing machine",
900
+ "898": "water bottle",
901
+ "899": "water jug",
902
+ "900": "water tower",
903
+ "901": "whiskey jug",
904
+ "902": "whistle",
905
+ "903": "wig",
906
+ "904": "window screen",
907
+ "905": "window shade",
908
+ "906": "Windsor tie",
909
+ "907": "wine bottle",
910
+ "908": "wing",
911
+ "909": "wok",
912
+ "910": "wooden spoon",
913
+ "911": "wool, woolen, woollen",
914
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
915
+ "913": "wreck",
916
+ "914": "yawl",
917
+ "915": "yurt",
918
+ "916": "web site, website, internet site, site",
919
+ "917": "comic book",
920
+ "918": "crossword puzzle, crossword",
921
+ "919": "street sign",
922
+ "920": "traffic light, traffic signal, stoplight",
923
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
924
+ "922": "menu",
925
+ "923": "plate",
926
+ "924": "guacamole",
927
+ "925": "consomme",
928
+ "926": "hot pot, hotpot",
929
+ "927": "trifle",
930
+ "928": "ice cream, icecream",
931
+ "929": "ice lolly, lolly, lollipop, popsicle",
932
+ "930": "French loaf",
933
+ "931": "bagel, beigel",
934
+ "932": "pretzel",
935
+ "933": "cheeseburger",
936
+ "934": "hotdog, hot dog, red hot",
937
+ "935": "mashed potato",
938
+ "936": "head cabbage",
939
+ "937": "broccoli",
940
+ "938": "cauliflower",
941
+ "939": "zucchini, courgette",
942
+ "940": "spaghetti squash",
943
+ "941": "acorn squash",
944
+ "942": "butternut squash",
945
+ "943": "cucumber, cuke",
946
+ "944": "artichoke, globe artichoke",
947
+ "945": "bell pepper",
948
+ "946": "cardoon",
949
+ "947": "mushroom",
950
+ "948": "Granny Smith",
951
+ "949": "strawberry",
952
+ "950": "orange",
953
+ "951": "lemon",
954
+ "952": "fig",
955
+ "953": "pineapple, ananas",
956
+ "954": "banana",
957
+ "955": "jackfruit, jak, jack",
958
+ "956": "custard apple",
959
+ "957": "pomegranate",
960
+ "958": "hay",
961
+ "959": "carbonara",
962
+ "960": "chocolate sauce, chocolate syrup",
963
+ "961": "dough",
964
+ "962": "meat loaf, meatloaf",
965
+ "963": "pizza, pizza pie",
966
+ "964": "potpie",
967
+ "965": "burrito",
968
+ "966": "red wine",
969
+ "967": "espresso",
970
+ "968": "cup",
971
+ "969": "eggnog",
972
+ "970": "alp",
973
+ "971": "bubble",
974
+ "972": "cliff, drop, drop-off",
975
+ "973": "coral reef",
976
+ "974": "geyser",
977
+ "975": "lakeside, lakeshore",
978
+ "976": "promontory, headland, head, foreland",
979
+ "977": "sandbar, sand bar",
980
+ "978": "seashore, coast, seacoast, sea-coast",
981
+ "979": "valley, vale",
982
+ "980": "volcano",
983
+ "981": "ballplayer, baseball player",
984
+ "982": "groom, bridegroom",
985
+ "983": "scuba diver",
986
+ "984": "rapeseed",
987
+ "985": "daisy",
988
+ "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ "987": "corn",
990
+ "988": "acorn",
991
+ "989": "hip, rose hip, rosehip",
992
+ "990": "buckeye, horse chestnut, conker",
993
+ "991": "coral fungus",
994
+ "992": "agaric",
995
+ "993": "gyromitra",
996
+ "994": "stinkhorn, carrion fungus",
997
+ "995": "earthstar",
998
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
999
+ "997": "bolete",
1000
+ "998": "ear, spike, capitulum",
1001
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1002
+ }