Instructions to use BiliSakura/JiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/JiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- JiT-B-16/model_index.json +9 -2
- JiT-B-16/pipeline.py +460 -0
- JiT-B-16/scheduler/scheduler_config.json +7 -0
- JiT-B-16/scheduler/scheduling_jit.py +161 -0
- JiT-B-16/transformer/config.json +18 -0
- JiT-B-16/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-B-16/transformer/jit_transformer_2d.py +500 -0
- JiT-B-32/model_index.json +9 -2
- JiT-B-32/pipeline.py +460 -0
- JiT-B-32/scheduler/scheduler_config.json +7 -0
- JiT-B-32/scheduler/scheduling_jit.py +161 -0
- JiT-B-32/transformer/config.json +18 -0
- JiT-B-32/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-B-32/transformer/jit_transformer_2d.py +500 -0
- JiT-H-16/model_index.json +9 -2
- JiT-H-16/pipeline.py +460 -0
- JiT-H-16/scheduler/scheduler_config.json +7 -0
- JiT-H-16/scheduler/scheduling_jit.py +161 -0
- JiT-H-16/transformer/config.json +18 -0
- JiT-H-16/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-H-16/transformer/jit_transformer_2d.py +500 -0
- JiT-H-32/model_index.json +10 -3
- JiT-H-32/pipeline.py +460 -0
- JiT-H-32/scheduler/scheduler_config.json +7 -0
- JiT-H-32/scheduler/scheduling_jit.py +161 -0
- JiT-H-32/transformer/config.json +18 -0
- JiT-H-32/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-H-32/transformer/jit_transformer_2d.py +500 -0
- JiT-L-16/model_index.json +9 -2
- JiT-L-16/pipeline.py +460 -0
- JiT-L-16/scheduler/scheduler_config.json +7 -0
- JiT-L-16/scheduler/scheduling_jit.py +161 -0
- JiT-L-16/transformer/config.json +18 -0
- JiT-L-16/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-L-16/transformer/jit_transformer_2d.py +500 -0
- JiT-L-32/model_index.json +9 -2
- JiT-L-32/pipeline.py +460 -0
- JiT-L-32/scheduler/scheduler_config.json +7 -0
- JiT-L-32/scheduler/scheduling_jit.py +161 -0
- JiT-L-32/transformer/config.json +18 -0
- JiT-L-32/transformer/diffusion_pytorch_model.safetensors +3 -0
- JiT-L-32/transformer/jit_transformer_2d.py +500 -0
- README.md +44 -54
- demo.png +2 -2
- demo_images/jit_h32_final_test.png +3 -0
- demo_images/jit_h32_test_inference.png +2 -2
- labels/__pycache__/imagenet_labels.cpython-312.pyc +0 -0
- labels/id2label_cn.json +1002 -0
- labels/id2label_en.json +1002 -0
.gitattributes
CHANGED
|
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
demo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
demo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
demo_images/jit_h32_final_test.png filter=lfs diff=lfs merge=lfs -text
|
JiT-B-16/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
}
|
JiT-B-16/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-B-16/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-B-16/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-B-16/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 128,
|
| 6 |
+
"dropout": 0.0,
|
| 7 |
+
"hidden_size": 768,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 4,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 12,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 12,
|
| 16 |
+
"patch_size": 16,
|
| 17 |
+
"sample_size": 256
|
| 18 |
+
}
|
JiT-B-16/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b68278f2e16a2842bbc17e7d38bc08d22475e1d748bb2e672a9b7e8aff5b4772
|
| 3 |
+
size 525298808
|
JiT-B-16/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
JiT-B-32/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
}
|
JiT-B-32/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-B-32/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-B-32/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-B-32/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 128,
|
| 6 |
+
"dropout": 0.0,
|
| 7 |
+
"hidden_size": 768,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 4,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 12,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 12,
|
| 16 |
+
"patch_size": 32,
|
| 17 |
+
"sample_size": 512
|
| 18 |
+
}
|
JiT-B-32/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:729654b3302fdae22eb4a4de9d2b24545828c82f2e2c8dcd3f5a01fe7c606ba4
|
| 3 |
+
size 533565560
|
JiT-B-32/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
JiT-H-16/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
}
|
JiT-H-16/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-H-16/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-H-16/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-H-16/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 256,
|
| 6 |
+
"dropout": 0.2,
|
| 7 |
+
"hidden_size": 1280,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 10,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 16,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 32,
|
| 16 |
+
"patch_size": 16,
|
| 17 |
+
"sample_size": 256
|
| 18 |
+
}
|
JiT-H-16/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6ad4cf51f5ff385db58573a23353b50df4be7a63dd50bdc7b57af404e7b68e7
|
| 3 |
+
size 3811413928
|
JiT-H-16/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
JiT-H-32/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
-
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
+
}
|
JiT-H-32/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-H-32/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-H-32/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-H-32/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 256,
|
| 6 |
+
"dropout": 0.2,
|
| 7 |
+
"hidden_size": 1280,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 10,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 16,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 32,
|
| 16 |
+
"patch_size": 32,
|
| 17 |
+
"sample_size": 512
|
| 18 |
+
}
|
JiT-H-32/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:578fc2f9f4ccaa34c3d2f5076811e101419e5dfd1b20dcca89bbfb29f5f60ab6
|
| 3 |
+
size 3825578920
|
JiT-H-32/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
JiT-L-16/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
}
|
JiT-L-16/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-L-16/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-L-16/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-L-16/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 128,
|
| 6 |
+
"dropout": 0.0,
|
| 7 |
+
"hidden_size": 1024,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 8,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 16,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 24,
|
| 16 |
+
"patch_size": 16,
|
| 17 |
+
"sample_size": 256
|
| 18 |
+
}
|
JiT-L-16/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9285393d92db078237e8adc552d6c9314c898c710ca1dfb4d3503fda0016b0f
|
| 3 |
+
size 1836593656
|
JiT-L-16/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
JiT-L-32/model_index.json
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
{
|
| 2 |
-
"_class_name":
|
|
|
|
|
|
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"transformer": [
|
| 5 |
-
"
|
| 6 |
"JiTTransformer2DModel"
|
| 7 |
]
|
| 8 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"JiTPipeline"
|
| 5 |
+
],
|
| 6 |
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_jit",
|
| 9 |
+
"JiTScheduler"
|
| 10 |
+
],
|
| 11 |
"transformer": [
|
| 12 |
+
"jit_transformer_2d",
|
| 13 |
"JiTTransformer2DModel"
|
| 14 |
]
|
| 15 |
}
|
JiT-L-32/pipeline.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
RECOMMENDED_NOISE_BY_SIZE = {
|
| 30 |
+
256: 1.0,
|
| 31 |
+
512: 2.0,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class JiTPipeline(DiffusionPipeline):
|
| 36 |
+
r"""
|
| 37 |
+
Pipeline for image generation using JiT (Just image Transformer).
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
transformer ([`JiTTransformer2DModel`]):
|
| 41 |
+
A class-conditioned `JiTTransformer2DModel` to denoise the images.
|
| 42 |
+
scheduler ([`JiTScheduler`]):
|
| 43 |
+
Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler).
|
| 44 |
+
id2label (`dict[int, str]`, *optional*):
|
| 45 |
+
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 46 |
+
id2label_cn (`dict[int, str]`, *optional*):
|
| 47 |
+
ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
model_cpu_offload_seq = "transformer"
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 54 |
+
"""Load a self-contained variant folder locally or from the Hub.
|
| 55 |
+
|
| 56 |
+
Examples:
|
| 57 |
+
JiTPipeline.from_pretrained(".")
|
| 58 |
+
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 59 |
+
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 60 |
+
"""
|
| 61 |
+
repo_root = Path(__file__).resolve().parent
|
| 62 |
+
|
| 63 |
+
if pretrained_model_name_or_path in (None, "", "."):
|
| 64 |
+
variant = repo_root
|
| 65 |
+
elif (
|
| 66 |
+
isinstance(pretrained_model_name_or_path, str)
|
| 67 |
+
and "/" in pretrained_model_name_or_path
|
| 68 |
+
and not Path(pretrained_model_name_or_path).exists()
|
| 69 |
+
):
|
| 70 |
+
from huggingface_hub import snapshot_download
|
| 71 |
+
|
| 72 |
+
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 73 |
+
if subfolder:
|
| 74 |
+
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"])
|
| 75 |
+
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 76 |
+
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 77 |
+
else:
|
| 78 |
+
variant = Path(pretrained_model_name_or_path)
|
| 79 |
+
if not variant.is_absolute():
|
| 80 |
+
candidate = (Path.cwd() / variant).resolve()
|
| 81 |
+
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 82 |
+
if subfolder:
|
| 83 |
+
variant = variant / subfolder
|
| 84 |
+
|
| 85 |
+
model_kwargs = dict(kwargs)
|
| 86 |
+
inserted: List[str] = []
|
| 87 |
+
|
| 88 |
+
def _load_component(folder: str, module_name: str, class_name: str):
|
| 89 |
+
comp_dir = variant / folder
|
| 90 |
+
module_path = comp_dir / f"{module_name}.py"
|
| 91 |
+
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 92 |
+
if not module_path.exists() or not has_weights:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
comp_path = str(comp_dir)
|
| 96 |
+
if comp_path not in sys.path:
|
| 97 |
+
sys.path.insert(0, comp_path)
|
| 98 |
+
inserted.append(comp_path)
|
| 99 |
+
|
| 100 |
+
module = importlib.import_module(module_name)
|
| 101 |
+
component_cls = getattr(module, class_name)
|
| 102 |
+
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 106 |
+
scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler")
|
| 107 |
+
|
| 108 |
+
if transformer is None:
|
| 109 |
+
raise ValueError(f"No loadable transformer found under {variant}")
|
| 110 |
+
|
| 111 |
+
variant_path = str(variant)
|
| 112 |
+
id2label, id2label_cn = cls._load_labels_for_variant(variant_path)
|
| 113 |
+
|
| 114 |
+
pipe = cls(
|
| 115 |
+
transformer=transformer,
|
| 116 |
+
scheduler=scheduler,
|
| 117 |
+
id2label=id2label,
|
| 118 |
+
id2label_cn=id2label_cn,
|
| 119 |
+
)
|
| 120 |
+
if variant_path and hasattr(pipe, "register_to_config"):
|
| 121 |
+
pipe.register_to_config(_name_or_path=variant_path)
|
| 122 |
+
return pipe
|
| 123 |
+
finally:
|
| 124 |
+
for comp_path in inserted:
|
| 125 |
+
if comp_path in sys.path:
|
| 126 |
+
sys.path.remove(comp_path)
|
| 127 |
+
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
transformer,
|
| 131 |
+
scheduler,
|
| 132 |
+
id2label: Optional[Dict[int, str]] = None,
|
| 133 |
+
id2label_cn: Optional[Dict[int, str]] = None,
|
| 134 |
+
):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 137 |
+
|
| 138 |
+
self._id2label = id2label or {}
|
| 139 |
+
self._id2label_cn = id2label_cn or {}
|
| 140 |
+
self.labels = self._build_label2id(self._id2label)
|
| 141 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 142 |
+
|
| 143 |
+
def _ensure_labels_loaded(self) -> None:
|
| 144 |
+
if self._id2label or self._id2label_cn:
|
| 145 |
+
return
|
| 146 |
+
loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None))
|
| 147 |
+
if loaded_en:
|
| 148 |
+
self._id2label = loaded_en
|
| 149 |
+
self.labels = self._build_label2id(self._id2label)
|
| 150 |
+
if loaded_cn:
|
| 151 |
+
self._id2label_cn = loaded_cn
|
| 152 |
+
self.labels_cn = self._build_label2id(self._id2label_cn)
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]:
|
| 156 |
+
if not variant_path:
|
| 157 |
+
return None
|
| 158 |
+
variant_dir = Path(variant_path).resolve()
|
| 159 |
+
labels_dir = variant_dir.parent / "labels"
|
| 160 |
+
return labels_dir if labels_dir.is_dir() else None
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]:
|
| 164 |
+
filename = "id2label_en.json" if lang == "en" else "id2label_cn.json"
|
| 165 |
+
path = labels_dir / filename
|
| 166 |
+
if not path.exists():
|
| 167 |
+
raise FileNotFoundError(path)
|
| 168 |
+
raw = json.loads(path.read_text(encoding="utf-8"))
|
| 169 |
+
return {int(key): value for key, value in raw.items()}
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def _load_labels_for_variant(
|
| 173 |
+
cls,
|
| 174 |
+
variant_path: Optional[str],
|
| 175 |
+
) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]:
|
| 176 |
+
labels_dir = cls._labels_dir_for_variant(variant_path)
|
| 177 |
+
if labels_dir is None:
|
| 178 |
+
return None, None
|
| 179 |
+
try:
|
| 180 |
+
return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn")
|
| 181 |
+
except FileNotFoundError:
|
| 182 |
+
return None, None
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
|
| 186 |
+
label2id: Dict[str, int] = {}
|
| 187 |
+
for class_id, value in id2label.items():
|
| 188 |
+
for synonym in value.split(","):
|
| 189 |
+
synonym = synonym.strip()
|
| 190 |
+
if synonym:
|
| 191 |
+
label2id[synonym] = int(class_id)
|
| 192 |
+
return dict(sorted(label2id.items()))
|
| 193 |
+
|
| 194 |
+
@property
|
| 195 |
+
def id2label(self) -> Dict[int, str]:
|
| 196 |
+
"""ImageNet class id to English label string (comma-separated synonyms)."""
|
| 197 |
+
self._ensure_labels_loaded()
|
| 198 |
+
return self._id2label
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def id2label_cn(self) -> Dict[int, str]:
|
| 202 |
+
"""ImageNet class id to Chinese label string (comma-separated synonyms)."""
|
| 203 |
+
self._ensure_labels_loaded()
|
| 204 |
+
return self._id2label_cn
|
| 205 |
+
|
| 206 |
+
def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]:
|
| 207 |
+
r"""
|
| 208 |
+
Map ImageNet label strings to class ids.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
label (`str` or `list[str]`):
|
| 212 |
+
One or more label strings. Each string must match a synonym in `id2label` (English)
|
| 213 |
+
or `id2label_cn` (Chinese).
|
| 214 |
+
lang (`str`, *optional*, defaults to `"en"`):
|
| 215 |
+
`"en"` uses English synonyms; `"cn"` uses Chinese synonyms.
|
| 216 |
+
"""
|
| 217 |
+
if lang not in ("en", "cn"):
|
| 218 |
+
raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.")
|
| 219 |
+
|
| 220 |
+
self._ensure_labels_loaded()
|
| 221 |
+
label2id = self.labels if lang == "en" else self.labels_cn
|
| 222 |
+
if not label2id:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder."
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if isinstance(label, str):
|
| 228 |
+
label = [label]
|
| 229 |
+
|
| 230 |
+
missing = [item for item in label if item not in label2id]
|
| 231 |
+
if missing:
|
| 232 |
+
preview = ", ".join(list(label2id.keys())[:8])
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..."
|
| 235 |
+
)
|
| 236 |
+
return [label2id[item] for item in label]
|
| 237 |
+
|
| 238 |
+
def _normalize_class_labels(
|
| 239 |
+
self,
|
| 240 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 241 |
+
) -> List[int]:
|
| 242 |
+
if isinstance(class_labels, int):
|
| 243 |
+
return [class_labels]
|
| 244 |
+
|
| 245 |
+
if isinstance(class_labels, str):
|
| 246 |
+
return self.get_label_ids(class_labels)
|
| 247 |
+
|
| 248 |
+
if class_labels and isinstance(class_labels[0], str):
|
| 249 |
+
self._ensure_labels_loaded()
|
| 250 |
+
if all(label in self.labels for label in class_labels):
|
| 251 |
+
return self.get_label_ids(class_labels, lang="en")
|
| 252 |
+
if all(label in self.labels_cn for label in class_labels):
|
| 253 |
+
return self.get_label_ids(class_labels, lang="cn")
|
| 254 |
+
raise ValueError(
|
| 255 |
+
"Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` "
|
| 256 |
+
"or Chinese synonyms from `pipe.labels_cn`."
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return list(class_labels)
|
| 260 |
+
|
| 261 |
+
def _predict_velocity(
|
| 262 |
+
self,
|
| 263 |
+
z_value: torch.Tensor,
|
| 264 |
+
t: torch.Tensor,
|
| 265 |
+
class_labels: torch.Tensor,
|
| 266 |
+
class_null: torch.Tensor,
|
| 267 |
+
do_classifier_free_guidance: bool,
|
| 268 |
+
guidance_scale: float,
|
| 269 |
+
guidance_interval_min: float,
|
| 270 |
+
guidance_interval_max: float,
|
| 271 |
+
) -> torch.Tensor:
|
| 272 |
+
t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
z_in = torch.cat([z_value, z_value], dim=0)
|
| 275 |
+
labels = torch.cat([class_labels, class_null], dim=0)
|
| 276 |
+
else:
|
| 277 |
+
z_in = z_value
|
| 278 |
+
labels = class_labels
|
| 279 |
+
|
| 280 |
+
t_batch = t.flatten().expand(z_in.shape[0])
|
| 281 |
+
x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample
|
| 282 |
+
v = self.scheduler.velocity_from_prediction(z_in, x_pred, t)
|
| 283 |
+
|
| 284 |
+
if not do_classifier_free_guidance:
|
| 285 |
+
return v
|
| 286 |
+
|
| 287 |
+
v_cond, v_uncond = v.chunk(2, dim=0)
|
| 288 |
+
interval_mask = t < guidance_interval_max
|
| 289 |
+
if guidance_interval_min != 0.0:
|
| 290 |
+
interval_mask = interval_mask & (t > guidance_interval_min)
|
| 291 |
+
scale = torch.where(
|
| 292 |
+
interval_mask,
|
| 293 |
+
torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype),
|
| 294 |
+
torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype),
|
| 295 |
+
)
|
| 296 |
+
return v_uncond + scale * (v_cond - v_uncond)
|
| 297 |
+
|
| 298 |
+
def _run_sampler(
|
| 299 |
+
self,
|
| 300 |
+
latents: torch.Tensor,
|
| 301 |
+
class_labels: torch.Tensor,
|
| 302 |
+
class_null: torch.Tensor,
|
| 303 |
+
num_inference_steps: int,
|
| 304 |
+
do_classifier_free_guidance: bool,
|
| 305 |
+
guidance_scale: float,
|
| 306 |
+
guidance_interval_min: float,
|
| 307 |
+
guidance_interval_max: float,
|
| 308 |
+
sampling_method: str,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
device = latents.device
|
| 311 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method)
|
| 312 |
+
timesteps = self.scheduler.timesteps
|
| 313 |
+
|
| 314 |
+
for i in self.progress_bar(range(num_inference_steps - 1)):
|
| 315 |
+
t = timesteps[i]
|
| 316 |
+
t_next = timesteps[i + 1]
|
| 317 |
+
v = self._predict_velocity(
|
| 318 |
+
latents,
|
| 319 |
+
t,
|
| 320 |
+
class_labels,
|
| 321 |
+
class_null,
|
| 322 |
+
do_classifier_free_guidance,
|
| 323 |
+
guidance_scale,
|
| 324 |
+
guidance_interval_min,
|
| 325 |
+
guidance_interval_max,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if sampling_method == "heun":
|
| 329 |
+
latents_euler = latents + (t_next - t) * v
|
| 330 |
+
v_next = self._predict_velocity(
|
| 331 |
+
latents_euler,
|
| 332 |
+
t_next,
|
| 333 |
+
class_labels,
|
| 334 |
+
class_null,
|
| 335 |
+
do_classifier_free_guidance,
|
| 336 |
+
guidance_scale,
|
| 337 |
+
guidance_interval_min,
|
| 338 |
+
guidance_interval_max,
|
| 339 |
+
)
|
| 340 |
+
latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample
|
| 341 |
+
else:
|
| 342 |
+
latents = self.scheduler.step(v, t, latents).prev_sample
|
| 343 |
+
|
| 344 |
+
t = timesteps[-2]
|
| 345 |
+
t_next = timesteps[-1]
|
| 346 |
+
v = self._predict_velocity(
|
| 347 |
+
latents,
|
| 348 |
+
t,
|
| 349 |
+
class_labels,
|
| 350 |
+
class_null,
|
| 351 |
+
do_classifier_free_guidance,
|
| 352 |
+
guidance_scale,
|
| 353 |
+
guidance_interval_min,
|
| 354 |
+
guidance_interval_max,
|
| 355 |
+
)
|
| 356 |
+
return latents + (t_next - t) * v
|
| 357 |
+
|
| 358 |
+
@torch.inference_mode()
|
| 359 |
+
def __call__(
|
| 360 |
+
self,
|
| 361 |
+
class_labels: Union[int, str, List[Union[int, str]]],
|
| 362 |
+
guidance_scale: Optional[float] = None,
|
| 363 |
+
guidance_interval_min: float = 0.1,
|
| 364 |
+
guidance_interval_max: float = 1.0,
|
| 365 |
+
noise_scale: Optional[float] = None,
|
| 366 |
+
t_eps: Optional[float] = None,
|
| 367 |
+
sampling_method: Optional[str] = None,
|
| 368 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
output_type: Optional[str] = "pil",
|
| 371 |
+
return_dict: bool = True,
|
| 372 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 373 |
+
r"""
|
| 374 |
+
Generate class-conditional images.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 378 |
+
ImageNet class indices or human-readable label strings (English or Chinese).
|
| 379 |
+
guidance_scale (`float`, *optional*):
|
| 380 |
+
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 381 |
+
guidance_interval_min (`float`, defaults to `0.1`):
|
| 382 |
+
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 383 |
+
guidance_interval_max (`float`, defaults to `1.0`):
|
| 384 |
+
Upper bound of the CFG interval in flow time.
|
| 385 |
+
noise_scale (`float`, *optional*):
|
| 386 |
+
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 387 |
+
t_eps (`float`, *optional*):
|
| 388 |
+
Epsilon clamp for the `1 - t` denominator (scheduler config by default).
|
| 389 |
+
sampling_method (`str`, *optional*):
|
| 390 |
+
`"heun"` or `"euler"`. Defaults to the scheduler config (`heun`).
|
| 391 |
+
generator (`torch.Generator`, *optional*):
|
| 392 |
+
RNG for reproducibility.
|
| 393 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 394 |
+
Number of solver steps (at least 2).
|
| 395 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 396 |
+
`"pil"`, `"np"`, or `"pt"`.
|
| 397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 398 |
+
Return [`ImagePipelineOutput`] if True.
|
| 399 |
+
"""
|
| 400 |
+
solver = sampling_method or self.scheduler.config.solver
|
| 401 |
+
if solver not in {"heun", "euler"}:
|
| 402 |
+
raise ValueError("sampling_method must be one of: 'heun', 'euler'.")
|
| 403 |
+
if num_inference_steps < 2:
|
| 404 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 405 |
+
|
| 406 |
+
if t_eps is not None:
|
| 407 |
+
self.scheduler.register_to_config(t_eps=t_eps)
|
| 408 |
+
|
| 409 |
+
class_label_ids = self._normalize_class_labels(class_labels)
|
| 410 |
+
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
| 411 |
+
|
| 412 |
+
batch_size = len(class_label_ids)
|
| 413 |
+
image_size = int(self.transformer.config.sample_size)
|
| 414 |
+
channels = int(self.transformer.config.in_channels)
|
| 415 |
+
null_class_val = int(self.transformer.config.num_classes)
|
| 416 |
+
|
| 417 |
+
if guidance_scale is None:
|
| 418 |
+
guidance_scale = 1.0
|
| 419 |
+
if noise_scale is None:
|
| 420 |
+
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0)
|
| 421 |
+
|
| 422 |
+
latents = (
|
| 423 |
+
randn_tensor(
|
| 424 |
+
shape=(batch_size, channels, image_size, image_size),
|
| 425 |
+
generator=generator,
|
| 426 |
+
device=self._execution_device,
|
| 427 |
+
dtype=self.transformer.dtype,
|
| 428 |
+
)
|
| 429 |
+
* noise_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 433 |
+
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
| 434 |
+
class_null = torch.full_like(class_labels_t, null_class_val)
|
| 435 |
+
|
| 436 |
+
latents = self._run_sampler(
|
| 437 |
+
latents,
|
| 438 |
+
class_labels_t,
|
| 439 |
+
class_null,
|
| 440 |
+
num_inference_steps,
|
| 441 |
+
do_classifier_free_guidance,
|
| 442 |
+
guidance_scale,
|
| 443 |
+
guidance_interval_min,
|
| 444 |
+
guidance_interval_max,
|
| 445 |
+
solver,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 449 |
+
if output_type == "pt":
|
| 450 |
+
images = images_pt
|
| 451 |
+
elif output_type == "np":
|
| 452 |
+
images = images_pt.permute(0, 2, 3, 1).numpy()
|
| 453 |
+
else:
|
| 454 |
+
images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy())
|
| 455 |
+
|
| 456 |
+
self.maybe_free_model_hooks()
|
| 457 |
+
|
| 458 |
+
if not return_dict:
|
| 459 |
+
return (images,)
|
| 460 |
+
return ImagePipelineOutput(images=images)
|
JiT-L-32/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"t_eps": 0.05,
|
| 6 |
+
"solver": "heun"
|
| 7 |
+
}
|
JiT-L-32/scheduler/scheduling_jit.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 22 |
+
from diffusers.utils import BaseOutput
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JiTSchedulerOutput(BaseOutput):
|
| 27 |
+
"""
|
| 28 |
+
Output class for the JiT scheduler's `step` function.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
prev_sample (`torch.Tensor`):
|
| 32 |
+
Updated sample after one solver step along the JiT flow-time grid.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
prev_sample: torch.Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class JiTScheduler(SchedulerMixin, ConfigMixin):
|
| 39 |
+
"""
|
| 40 |
+
Manual flow-matching scheduler for JiT checkpoints.
|
| 41 |
+
|
| 42 |
+
Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT
|
| 43 |
+
sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or
|
| 44 |
+
Heun along that grid.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
order = 2
|
| 48 |
+
|
| 49 |
+
@register_to_config
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
num_train_timesteps: int = 1000,
|
| 53 |
+
t_eps: float = 5e-2,
|
| 54 |
+
solver: str = "heun",
|
| 55 |
+
):
|
| 56 |
+
if solver not in {"heun", "euler"}:
|
| 57 |
+
raise ValueError("solver must be one of: 'heun', 'euler'.")
|
| 58 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 59 |
+
self.sigmas: Optional[List[float]] = None
|
| 60 |
+
self.num_inference_steps: Optional[int] = None
|
| 61 |
+
self._step_index: Optional[int] = None
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def init_noise_sigma(self) -> float:
|
| 65 |
+
return 1.0
|
| 66 |
+
|
| 67 |
+
def set_timesteps(
|
| 68 |
+
self,
|
| 69 |
+
num_inference_steps: int,
|
| 70 |
+
device: Union[str, torch.device, None] = None,
|
| 71 |
+
solver: Optional[str] = None,
|
| 72 |
+
) -> None:
|
| 73 |
+
if num_inference_steps < 2:
|
| 74 |
+
raise ValueError("num_inference_steps must be >= 2.")
|
| 75 |
+
|
| 76 |
+
self.num_inference_steps = num_inference_steps
|
| 77 |
+
self.timesteps = torch.linspace(
|
| 78 |
+
0.0,
|
| 79 |
+
1.0,
|
| 80 |
+
num_inference_steps + 1,
|
| 81 |
+
device=device,
|
| 82 |
+
dtype=torch.float32,
|
| 83 |
+
)
|
| 84 |
+
sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32)
|
| 85 |
+
self.sigmas = (1.0 - sigma_grid).tolist()
|
| 86 |
+
self._step_index = 0
|
| 87 |
+
if solver is not None:
|
| 88 |
+
self.register_to_config(solver=solver)
|
| 89 |
+
|
| 90 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 91 |
+
del timestep
|
| 92 |
+
return sample
|
| 93 |
+
|
| 94 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 95 |
+
if self._step_index is not None:
|
| 96 |
+
return self._step_index
|
| 97 |
+
if self.timesteps is None:
|
| 98 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 99 |
+
if timestep is None:
|
| 100 |
+
return 0
|
| 101 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 102 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 103 |
+
if matches.any():
|
| 104 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 105 |
+
return 0
|
| 106 |
+
|
| 107 |
+
def step(
|
| 108 |
+
self,
|
| 109 |
+
model_output: torch.Tensor,
|
| 110 |
+
timestep: Union[float, torch.Tensor, None],
|
| 111 |
+
sample: torch.Tensor,
|
| 112 |
+
model_output_next: Optional[torch.Tensor] = None,
|
| 113 |
+
return_dict: bool = True,
|
| 114 |
+
) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]:
|
| 115 |
+
"""
|
| 116 |
+
Integrate one step on the linear `t` grid.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
model_output (`torch.Tensor`):
|
| 120 |
+
Velocity `v = (x_pred - z) / (1 - t)` at the current time.
|
| 121 |
+
timestep (`float` or `torch.Tensor`, *optional*):
|
| 122 |
+
Current flow time `t`. When omitted, uses the internal step index.
|
| 123 |
+
sample (`torch.Tensor`):
|
| 124 |
+
Current noisy latent `z`.
|
| 125 |
+
model_output_next (`torch.Tensor`, *optional*):
|
| 126 |
+
Velocity at `t_next` (required for Heun intermediate steps).
|
| 127 |
+
"""
|
| 128 |
+
if self.timesteps is None:
|
| 129 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 130 |
+
|
| 131 |
+
step_index = self._resolve_step_index(timestep)
|
| 132 |
+
if step_index >= len(self.timesteps) - 1:
|
| 133 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 134 |
+
|
| 135 |
+
t = self.timesteps[step_index]
|
| 136 |
+
t_next = self.timesteps[step_index + 1]
|
| 137 |
+
dt = t_next - t
|
| 138 |
+
|
| 139 |
+
if self.config.solver == "heun" and model_output_next is not None:
|
| 140 |
+
prev_sample = sample + dt * 0.5 * (model_output + model_output_next)
|
| 141 |
+
else:
|
| 142 |
+
prev_sample = sample + dt * model_output
|
| 143 |
+
|
| 144 |
+
self._step_index = step_index + 1
|
| 145 |
+
|
| 146 |
+
if not return_dict:
|
| 147 |
+
return (prev_sample,)
|
| 148 |
+
return JiTSchedulerOutput(prev_sample=prev_sample)
|
| 149 |
+
|
| 150 |
+
def velocity_from_prediction(
|
| 151 |
+
self,
|
| 152 |
+
sample: torch.Tensor,
|
| 153 |
+
x_pred: torch.Tensor,
|
| 154 |
+
timestep: Union[float, torch.Tensor],
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp."""
|
| 157 |
+
t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype)
|
| 158 |
+
while t.ndim < sample.ndim:
|
| 159 |
+
t = t.unsqueeze(-1)
|
| 160 |
+
denom = (1.0 - t).clamp_min(self.config.t_eps)
|
| 161 |
+
return (x_pred - sample) / denom
|
JiT-L-32/transformer/config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "JiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"attention_dropout": 0.0,
|
| 5 |
+
"bottleneck_dim": 128,
|
| 6 |
+
"dropout": 0.0,
|
| 7 |
+
"hidden_size": 1024,
|
| 8 |
+
"in_channels": 3,
|
| 9 |
+
"in_context_len": 32,
|
| 10 |
+
"in_context_start": 8,
|
| 11 |
+
"mlp_ratio": 4.0,
|
| 12 |
+
"norm_eps": 1e-06,
|
| 13 |
+
"num_attention_heads": 16,
|
| 14 |
+
"num_classes": 1000,
|
| 15 |
+
"num_layers": 24,
|
| 16 |
+
"patch_size": 32,
|
| 17 |
+
"sample_size": 512
|
| 18 |
+
}
|
JiT-L-32/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:121d3917ab50ad034295646734eb9b898167f19419dd65d22946f38c7d183266
|
| 3 |
+
size 1847219704
|
JiT-L-32/transformer/jit_transformer_2d.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 24 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 25 |
+
from diffusers.models.normalization import RMSNorm
|
| 26 |
+
from diffusers.utils import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def broadcat(tensors, dim=-1):
|
| 33 |
+
num_tensors = len(tensors)
|
| 34 |
+
shape_lens = {len(t.shape) for t in tensors}
|
| 35 |
+
if len(shape_lens) != 1:
|
| 36 |
+
raise ValueError("tensors must all have the same number of dimensions")
|
| 37 |
+
shape_len = list(shape_lens)[0]
|
| 38 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 39 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
| 40 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 41 |
+
|
| 42 |
+
if not all(len(set(t[1])) <= 2 for t in expandable_dims):
|
| 43 |
+
raise ValueError("invalid dimensions for broadcastable concatenation")
|
| 44 |
+
|
| 45 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
| 46 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
| 47 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 48 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
| 49 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
| 50 |
+
return torch.cat(tensors, dim=dim)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rotate_half(x):
|
| 54 |
+
x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2)
|
| 55 |
+
x1, x2 = x.unbind(dim=-1)
|
| 56 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 57 |
+
return x.view(*x.shape[:-2], -1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class JiTRotaryEmbedding(nn.Module):
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim,
|
| 64 |
+
pt_seq_len=16,
|
| 65 |
+
ft_seq_len=None,
|
| 66 |
+
custom_freqs=None,
|
| 67 |
+
theta=10000,
|
| 68 |
+
num_cls_token=0,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
if custom_freqs is not None:
|
| 72 |
+
freqs = custom_freqs
|
| 73 |
+
else:
|
| 74 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 75 |
+
|
| 76 |
+
if ft_seq_len is None:
|
| 77 |
+
ft_seq_len = pt_seq_len
|
| 78 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 79 |
+
|
| 80 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 81 |
+
freqs = freqs.repeat_interleave(2, dim=-1)
|
| 82 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 83 |
+
|
| 84 |
+
if num_cls_token > 0:
|
| 85 |
+
freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D]
|
| 86 |
+
cos_img = freqs_flat.cos()
|
| 87 |
+
sin_img = freqs_flat.sin()
|
| 88 |
+
|
| 89 |
+
# prepend in-context cls token
|
| 90 |
+
_, D = cos_img.shape
|
| 91 |
+
cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype)
|
| 92 |
+
sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype)
|
| 93 |
+
|
| 94 |
+
self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False)
|
| 95 |
+
self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False)
|
| 96 |
+
else:
|
| 97 |
+
self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False)
|
| 98 |
+
self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False)
|
| 99 |
+
|
| 100 |
+
def forward(self, t):
|
| 101 |
+
# Applied on (batch, seq_len, heads, head_dim) tensors from attention.
|
| 102 |
+
seq_len = t.shape[1]
|
| 103 |
+
freqs_cos = self.freqs_cos[:seq_len].to(t.dtype)
|
| 104 |
+
freqs_sin = self.freqs_sin[:seq_len].to(t.dtype)
|
| 105 |
+
|
| 106 |
+
return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def modulate(x, shift, scale):
|
| 110 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class JiTPatchEmbed(nn.Module):
|
| 114 |
+
"""Image to Patch Embedding with Bottleneck"""
|
| 115 |
+
|
| 116 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True):
|
| 117 |
+
super().__init__()
|
| 118 |
+
img_size = (img_size, img_size)
|
| 119 |
+
patch_size = (patch_size, patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 123 |
+
|
| 124 |
+
self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 125 |
+
self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class JiTTimestepEmbedder(nn.Module):
|
| 133 |
+
"""
|
| 134 |
+
Embeds scalar timesteps into vector representations.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.mlp = nn.Sequential(
|
| 140 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 143 |
+
)
|
| 144 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 148 |
+
"""
|
| 149 |
+
Create sinusoidal timestep embeddings.
|
| 150 |
+
"""
|
| 151 |
+
half = dim // 2
|
| 152 |
+
freqs = torch.exp(
|
| 153 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 154 |
+
).to(device=t.device)
|
| 155 |
+
args = t[:, None].float() * freqs[None]
|
| 156 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 157 |
+
if dim % 2:
|
| 158 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 159 |
+
return embedding
|
| 160 |
+
|
| 161 |
+
def forward(self, t, dtype=None):
|
| 162 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 163 |
+
if dtype is not None:
|
| 164 |
+
t_freq = t_freq.to(dtype=dtype)
|
| 165 |
+
t_emb = self.mlp(t_freq)
|
| 166 |
+
return t_emb
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class JiTLabelEmbedder(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Embeds class labels into vector representations.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, num_classes, hidden_size):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.embedding_table = nn.Embedding(num_classes + 1, hidden_size)
|
| 177 |
+
self.num_classes = num_classes
|
| 178 |
+
|
| 179 |
+
def forward(self, labels):
|
| 180 |
+
embeddings = self.embedding_table(labels)
|
| 181 |
+
return embeddings
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class JiTAttention(nn.Module):
|
| 185 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.num_heads = num_heads
|
| 188 |
+
head_dim = dim // num_heads
|
| 189 |
+
|
| 190 |
+
self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 191 |
+
self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
|
| 192 |
+
|
| 193 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 194 |
+
self.attn_drop = attn_drop
|
| 195 |
+
self.proj = nn.Linear(dim, dim)
|
| 196 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, rope=None):
|
| 199 |
+
B, N, C = x.shape
|
| 200 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 201 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 202 |
+
|
| 203 |
+
q = self.q_norm(q)
|
| 204 |
+
k = self.k_norm(k)
|
| 205 |
+
|
| 206 |
+
if rope is not None:
|
| 207 |
+
q = q.transpose(1, 2)
|
| 208 |
+
k = k.transpose(1, 2)
|
| 209 |
+
q = rope(q)
|
| 210 |
+
k = rope(k)
|
| 211 |
+
q = q.transpose(1, 2)
|
| 212 |
+
k = k.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
| 215 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 216 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 217 |
+
x = self.proj(x)
|
| 218 |
+
x = self.proj_drop(x)
|
| 219 |
+
return x
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class JiTSwiGLUFFN(nn.Module):
|
| 223 |
+
def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None:
|
| 224 |
+
super().__init__()
|
| 225 |
+
hidden_dim = int(hidden_dim * 2 / 3)
|
| 226 |
+
self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias)
|
| 227 |
+
self.w3 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 228 |
+
self.ffn_dropout = nn.Dropout(drop)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
x12 = self.w12(x)
|
| 232 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 233 |
+
hidden = F.silu(x1) * x2
|
| 234 |
+
return self.w3(self.ffn_dropout(hidden))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class JiTBlock(nn.Module):
|
| 238 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.norm1 = RMSNorm(hidden_size, eps=eps)
|
| 241 |
+
self.attn = JiTAttention(
|
| 242 |
+
hidden_size,
|
| 243 |
+
num_heads=num_heads,
|
| 244 |
+
qkv_bias=True,
|
| 245 |
+
qk_norm=True,
|
| 246 |
+
attn_drop=attn_drop,
|
| 247 |
+
proj_drop=proj_drop,
|
| 248 |
+
eps=eps,
|
| 249 |
+
)
|
| 250 |
+
self.norm2 = RMSNorm(hidden_size, eps=eps)
|
| 251 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 252 |
+
self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop)
|
| 253 |
+
|
| 254 |
+
self.act = nn.SiLU()
|
| 255 |
+
self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 256 |
+
|
| 257 |
+
def forward(self, x, c, feat_rope=None):
|
| 258 |
+
# Apply activation
|
| 259 |
+
c = self.act(c)
|
| 260 |
+
|
| 261 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 262 |
+
|
| 263 |
+
# Attention block
|
| 264 |
+
norm_x = self.norm1(x)
|
| 265 |
+
modulated_x = modulate(norm_x, shift_msa, scale_msa)
|
| 266 |
+
attn_out = self.attn(modulated_x, rope=feat_rope)
|
| 267 |
+
x = x + gate_msa.unsqueeze(1) * attn_out
|
| 268 |
+
|
| 269 |
+
# MLP block
|
| 270 |
+
norm_x = self.norm2(x)
|
| 271 |
+
modulated_x = modulate(norm_x, shift_mlp, scale_mlp)
|
| 272 |
+
mlp_out = self.mlp(modulated_x)
|
| 273 |
+
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
| 274 |
+
|
| 275 |
+
return x
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 279 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 280 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 281 |
+
grid = np.meshgrid(grid_w, grid_h)
|
| 282 |
+
grid = np.stack(grid, axis=0)
|
| 283 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 284 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 285 |
+
if cls_token and extra_tokens > 0:
|
| 286 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 287 |
+
return pos_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 291 |
+
if embed_dim % 2 != 0:
|
| 292 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 293 |
+
|
| 294 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
|
| 295 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
|
| 296 |
+
emb = np.concatenate([emb_h, emb_w], axis=1)
|
| 297 |
+
return emb
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 301 |
+
if embed_dim % 2 != 0:
|
| 302 |
+
raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}")
|
| 303 |
+
|
| 304 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 305 |
+
omega /= embed_dim / 2.0
|
| 306 |
+
omega = 1.0 / 10000**omega
|
| 307 |
+
|
| 308 |
+
pos = pos.reshape(-1)
|
| 309 |
+
out = np.einsum("m,d->md", pos, omega)
|
| 310 |
+
|
| 311 |
+
emb_sin = np.sin(out)
|
| 312 |
+
emb_cos = np.cos(out)
|
| 313 |
+
|
| 314 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1)
|
| 315 |
+
return emb
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
|
| 319 |
+
r"""
|
| 320 |
+
A 2D Transformer for pixel-space class-conditional generation with JiT
|
| 321 |
+
([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)).
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`, defaults to `256`):
|
| 325 |
+
Input image resolution (height and width).
|
| 326 |
+
patch_size (`int`, defaults to `16`):
|
| 327 |
+
Patch size for the bottleneck patch embedder.
|
| 328 |
+
in_channels (`int`, defaults to `3`):
|
| 329 |
+
Number of input image channels.
|
| 330 |
+
hidden_size (`int`, defaults to `768`):
|
| 331 |
+
Transformer hidden dimension.
|
| 332 |
+
num_layers (`int`, defaults to `12`):
|
| 333 |
+
Number of JiT transformer blocks.
|
| 334 |
+
num_attention_heads (`int`, defaults to `12`):
|
| 335 |
+
Number of attention heads per block.
|
| 336 |
+
mlp_ratio (`float`, defaults to `4.0`):
|
| 337 |
+
MLP hidden dimension multiplier.
|
| 338 |
+
attention_dropout (`float`, defaults to `0.0`):
|
| 339 |
+
Attention dropout in the middle quarter of blocks.
|
| 340 |
+
dropout (`float`, defaults to `0.0`):
|
| 341 |
+
Projection dropout in the middle quarter of blocks.
|
| 342 |
+
num_classes (`int`, defaults to `1000`):
|
| 343 |
+
Number of class labels (null label uses index `num_classes` for CFG).
|
| 344 |
+
bottleneck_dim (`int`, defaults to `128`):
|
| 345 |
+
PCA bottleneck dimension in the patch embedder.
|
| 346 |
+
in_context_len (`int`, defaults to `32`):
|
| 347 |
+
Number of in-context class tokens prepended mid-network.
|
| 348 |
+
in_context_start (`int`, defaults to `4`):
|
| 349 |
+
Block index at which in-context tokens are inserted.
|
| 350 |
+
norm_eps (`float`, defaults to `1e-6`):
|
| 351 |
+
Epsilon for RMSNorm layers.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
_supports_gradient_checkpointing = True
|
| 355 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 356 |
+
|
| 357 |
+
@register_to_config
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
sample_size: int = 256,
|
| 361 |
+
patch_size: int = 16,
|
| 362 |
+
in_channels: int = 3,
|
| 363 |
+
hidden_size: int = 768,
|
| 364 |
+
num_layers: int = 12,
|
| 365 |
+
num_attention_heads: int = 12,
|
| 366 |
+
mlp_ratio: float = 4.0,
|
| 367 |
+
attention_dropout: float = 0.0,
|
| 368 |
+
dropout: float = 0.0,
|
| 369 |
+
num_classes: int = 1000,
|
| 370 |
+
bottleneck_dim: int = 128,
|
| 371 |
+
in_context_len: int = 32,
|
| 372 |
+
in_context_start: int = 4,
|
| 373 |
+
norm_eps: float = 1e-6,
|
| 374 |
+
):
|
| 375 |
+
super().__init__()
|
| 376 |
+
self.sample_size = sample_size
|
| 377 |
+
self.patch_size = patch_size
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = in_channels
|
| 380 |
+
self.hidden_size = hidden_size
|
| 381 |
+
self.num_layers = num_layers
|
| 382 |
+
self.num_attention_heads = num_attention_heads
|
| 383 |
+
self.in_context_len = in_context_len
|
| 384 |
+
self.in_context_start = in_context_start
|
| 385 |
+
self.norm_eps = norm_eps
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
# Time and Class Embedding
|
| 389 |
+
self.t_embedder = JiTTimestepEmbedder(hidden_size)
|
| 390 |
+
self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size)
|
| 391 |
+
|
| 392 |
+
# Patch Embedding
|
| 393 |
+
self.x_embedder = JiTPatchEmbed(
|
| 394 |
+
img_size=sample_size,
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
in_chans=in_channels,
|
| 397 |
+
pca_dim=bottleneck_dim,
|
| 398 |
+
embed_dim=hidden_size,
|
| 399 |
+
bias=True,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Positional Embedding (Fixed Sin-Cos)
|
| 403 |
+
num_patches = self.x_embedder.num_patches
|
| 404 |
+
pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5))
|
| 405 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
| 406 |
+
|
| 407 |
+
# In-context Embedding
|
| 408 |
+
if self.in_context_len > 0:
|
| 409 |
+
self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size))
|
| 410 |
+
|
| 411 |
+
# RoPE
|
| 412 |
+
half_head_dim = hidden_size // num_attention_heads // 2
|
| 413 |
+
hw_seq_len = sample_size // patch_size
|
| 414 |
+
self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0)
|
| 415 |
+
self.feat_rope_incontext = JiTRotaryEmbedding(
|
| 416 |
+
dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Blocks
|
| 420 |
+
self.blocks = nn.ModuleList(
|
| 421 |
+
[
|
| 422 |
+
JiTBlock(
|
| 423 |
+
hidden_size,
|
| 424 |
+
num_attention_heads,
|
| 425 |
+
mlp_ratio=mlp_ratio,
|
| 426 |
+
attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 427 |
+
proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0,
|
| 428 |
+
eps=norm_eps,
|
| 429 |
+
)
|
| 430 |
+
for i in range(num_layers)
|
| 431 |
+
]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Final Layer
|
| 435 |
+
self.norm_final = RMSNorm(hidden_size, eps=norm_eps)
|
| 436 |
+
self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
|
| 437 |
+
self.act_final = nn.SiLU()
|
| 438 |
+
self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 439 |
+
|
| 440 |
+
def forward(
|
| 441 |
+
self,
|
| 442 |
+
hidden_states: torch.Tensor,
|
| 443 |
+
timestep: torch.LongTensor,
|
| 444 |
+
class_labels: torch.LongTensor,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
):
|
| 447 |
+
|
| 448 |
+
t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
|
| 449 |
+
y_emb = self.y_embedder(class_labels)
|
| 450 |
+
|
| 451 |
+
# Ensure embeddings match hidden_states dtype
|
| 452 |
+
y_emb = y_emb.to(dtype=hidden_states.dtype)
|
| 453 |
+
|
| 454 |
+
c = t_emb + y_emb
|
| 455 |
+
|
| 456 |
+
# Patch Embed
|
| 457 |
+
x = self.x_embedder(hidden_states)
|
| 458 |
+
x = x + self.pos_embed.to(x.dtype)
|
| 459 |
+
|
| 460 |
+
# Blocks
|
| 461 |
+
for i, block in enumerate(self.blocks):
|
| 462 |
+
if self.in_context_len > 0 and i == self.in_context_start:
|
| 463 |
+
in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1)
|
| 464 |
+
in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype)
|
| 465 |
+
x = torch.cat([in_context_tokens, x], dim=1)
|
| 466 |
+
|
| 467 |
+
rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext
|
| 468 |
+
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
block,
|
| 472 |
+
x,
|
| 473 |
+
c,
|
| 474 |
+
rope,
|
| 475 |
+
use_reentrant=False,
|
| 476 |
+
)
|
| 477 |
+
else:
|
| 478 |
+
x = block(x, c, feat_rope=rope)
|
| 479 |
+
|
| 480 |
+
# Slice off in-context tokens
|
| 481 |
+
if self.in_context_len > 0:
|
| 482 |
+
x = x[:, self.in_context_len :]
|
| 483 |
+
|
| 484 |
+
# Final Layer
|
| 485 |
+
c = self.act_final(c)
|
| 486 |
+
shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1)
|
| 487 |
+
|
| 488 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 489 |
+
x = self.linear_final(x)
|
| 490 |
+
|
| 491 |
+
# Unpatchify
|
| 492 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 493 |
+
x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels))
|
| 494 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 495 |
+
output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size))
|
| 496 |
+
|
| 497 |
+
if not return_dict:
|
| 498 |
+
return (output,)
|
| 499 |
+
|
| 500 |
+
return Transformer2DModelOutput(sample=output)
|
README.md
CHANGED
|
@@ -14,77 +14,67 @@ language:
|
|
| 14 |
- en
|
| 15 |
---
|
| 16 |
|
| 17 |
-
# JiT-
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|---|---|---|---|---:|---|---:|---:|
|
| 27 |
-
| JiT-B/16 | `./JiT-B-16` | 256x256 | ImageNet 256x256 | 3.0 | `[0.1, 1.0]` | 1.0 | 3.66 |
|
| 28 |
-
| JiT-L/16 | `./JiT-L-16` | 256x256 | ImageNet 256x256 | 2.4 | `[0.1, 1.0]` | 1.0 | 2.36 |
|
| 29 |
-
| JiT-H/16 | `./JiT-H-16` | 256x256 | ImageNet 256x256 | 2.2 | `[0.1, 1.0]` | 1.0 | 1.86 |
|
| 30 |
-
| JiT-B/32 | `./JiT-B-32` | 512x512 | ImageNet 512x512 | 3.0 | `[0.1, 1.0]` | 2.0 | 4.02 |
|
| 31 |
-
| JiT-L/32 | `./JiT-L-32` | 512x512 | ImageNet 512x512 | 2.5 | `[0.1, 1.0]` | 2.0 | 2.53 |
|
| 32 |
-
| JiT-H/32 | `./JiT-H-32` | 512x512 | ImageNet 512x512 | 2.3 | `[0.1, 1.0]` | 2.0 | 1.94 |
|
| 33 |
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
```python
|
| 43 |
-
from
|
| 44 |
-
import sys
|
| 45 |
import torch
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
pipe
|
| 53 |
-
pipe.
|
| 54 |
-
pipe.transformer.eval()
|
| 55 |
|
| 56 |
-
generator = torch.Generator(device=
|
| 57 |
-
|
| 58 |
-
class_labels=
|
| 59 |
num_inference_steps=50,
|
| 60 |
guidance_scale=2.3,
|
| 61 |
-
guidance_interval_min=0.1,
|
| 62 |
-
guidance_interval_max=1.0,
|
| 63 |
-
noise_scale=2.0,
|
| 64 |
-
t_eps=5e-2,
|
| 65 |
sampling_method="heun",
|
| 66 |
generator=generator,
|
| 67 |
-
|
| 68 |
-
)
|
| 69 |
-
image = output.images[0]
|
| 70 |
-
output_path = Path("./demo_images/jit_h32_test_inference.png")
|
| 71 |
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 72 |
-
image.save(output_path)
|
| 73 |
-
print(f"Saved image to: {output_path}")
|
| 74 |
```
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
Run these from this repository root (`models/BiliSakura/JiT-diffusers`).
|
| 79 |
-
|
| 80 |
-
```bash
|
| 81 |
-
# 256x256 checkpoints
|
| 82 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-B-16 --output_path ./demo_images/jit_b16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
|
| 83 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-L-16 --output_path ./demo_images/jit_l16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.4 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
|
| 84 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-H-16 --output_path ./demo_images/jit_h16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.2 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun
|
| 85 |
-
|
| 86 |
-
# 512x512 checkpoints
|
| 87 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-B-32 --output_path ./demo_images/jit_b32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
|
| 88 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-L-32 --output_path ./demo_images/jit_l32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.5 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
|
| 89 |
-
python run_jit_diffusers_inference.py --model_path ./JiT-H-32 --output_path ./demo_images/jit_h32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.3 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun
|
| 90 |
-
```
|
|
|
|
| 14 |
- en
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# JiT-diffusers
|
| 18 |
|
| 19 |
+
Native diffusers implementation of **JiT** (Just image Transformer). Each variant folder is self-contained:
|
| 20 |
|
| 21 |
+
- `pipeline.py` — `JiTPipeline`
|
| 22 |
+
- `scheduler/scheduling_jit.py` — `JiTScheduler` (linear `t in [0, 1]`, Heun/Euler)
|
| 23 |
+
- `transformer/jit_transformer_2d.py` — `JiTTransformer2DModel`
|
| 24 |
|
| 25 |
+
Shared ImageNet-1k labels live in [`labels/`](labels/) at the repo root (not duplicated per variant).
|
| 26 |
|
| 27 |
+
No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom code in the variant directory.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
## Available checkpoints
|
| 30 |
|
| 31 |
+
| Checkpoint | Path | Resolution | Recommended CFG |
|
| 32 |
+
|---|---|---|---|
|
| 33 |
+
| JiT-B/16 | `./JiT-B-16` | 256×256 | 3.0 |
|
| 34 |
+
| JiT-L/16 | `./JiT-L-16` | 256×256 | 2.4 |
|
| 35 |
+
| JiT-H/16 | `./JiT-H-16` | 256×256 | 2.2 |
|
| 36 |
+
| JiT-B/32 | `./JiT-B-32` | 512×512 | 3.0 |
|
| 37 |
+
| JiT-L/32 | `./JiT-L-32` | 512×512 | 2.5 |
|
| 38 |
+
| JiT-H/32 | `./JiT-H-32` | 512×512 | 2.3 |
|
| 39 |
|
| 40 |
+
## ImageNet class labels
|
| 41 |
|
| 42 |
+
| File | Direction | Format |
|
| 43 |
+
|---|---|---|
|
| 44 |
+
| `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` |
|
| 45 |
+
| `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` |
|
| 46 |
+
|
| 47 |
+
- `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence
|
| 48 |
+
- `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id), sorted for browsing
|
| 49 |
+
- `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")`
|
| 50 |
+
- `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically
|
| 51 |
+
|
| 52 |
+
## Inference
|
| 53 |
|
| 54 |
```python
|
| 55 |
+
from diffusers import DiffusionPipeline
|
|
|
|
| 56 |
import torch
|
| 57 |
|
| 58 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 59 |
+
"./JiT-H-32",
|
| 60 |
+
trust_remote_code=True,
|
| 61 |
+
)
|
| 62 |
+
pipe.to("cuda")
|
| 63 |
+
pipe.transformer.to(dtype=torch.bfloat16)
|
| 64 |
|
| 65 |
+
# Numeric or human-readable labels
|
| 66 |
+
print(pipe.id2label[207])
|
| 67 |
+
print(pipe.get_label_ids("golden retriever"))
|
|
|
|
| 68 |
|
| 69 |
+
generator = torch.Generator(device="cuda").manual_seed(42)
|
| 70 |
+
images = pipe(
|
| 71 |
+
class_labels="golden retriever",
|
| 72 |
num_inference_steps=50,
|
| 73 |
guidance_scale=2.3,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
sampling_method="heun",
|
| 75 |
generator=generator,
|
| 76 |
+
).images
|
| 77 |
+
images[0].save("output.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
```
|
| 79 |
|
| 80 |
+
Load a **variant subfolder** (e.g. `./JiT-H-32`), not the repo root.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
demo_images/jit_h32_final_test.png
ADDED
|
Git LFS Details
|
demo_images/jit_h32_test_inference.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
labels/__pycache__/imagenet_labels.cpython-312.pyc
ADDED
|
Binary file (3.24 kB). View file
|
|
|
labels/id2label_cn.json
ADDED
|
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"0": "丁鲷",
|
| 3 |
+
"1": "金鱼",
|
| 4 |
+
"2": "大白鲨",
|
| 5 |
+
"3": "虎鲨",
|
| 6 |
+
"4": "锤头鲨",
|
| 7 |
+
"5": "电鳐",
|
| 8 |
+
"6": "黄貂鱼",
|
| 9 |
+
"7": "公鸡",
|
| 10 |
+
"8": "母鸡",
|
| 11 |
+
"9": "鸵鸟",
|
| 12 |
+
"10": "燕雀",
|
| 13 |
+
"11": "金翅雀",
|
| 14 |
+
"12": "家朱雀",
|
| 15 |
+
"13": "灯芯草雀",
|
| 16 |
+
"14": "靛蓝雀,靛蓝鸟",
|
| 17 |
+
"15": "蓝鹀",
|
| 18 |
+
"16": "夜莺",
|
| 19 |
+
"17": "松鸦",
|
| 20 |
+
"18": "喜鹊",
|
| 21 |
+
"19": "山雀",
|
| 22 |
+
"20": "河鸟",
|
| 23 |
+
"21": "鸢(猛禽)",
|
| 24 |
+
"22": "秃头鹰",
|
| 25 |
+
"23": "秃鹫",
|
| 26 |
+
"24": "大灰猫头鹰",
|
| 27 |
+
"25": "欧洲火蝾螈",
|
| 28 |
+
"26": "普通蝾螈",
|
| 29 |
+
"27": "水蜥",
|
| 30 |
+
"28": "斑点蝾螈",
|
| 31 |
+
"29": "蝾螈,泥狗",
|
| 32 |
+
"30": "牛蛙",
|
| 33 |
+
"31": "树蛙",
|
| 34 |
+
"32": "尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍",
|
| 35 |
+
"33": "红海龟",
|
| 36 |
+
"34": "皮革龟",
|
| 37 |
+
"35": "泥龟",
|
| 38 |
+
"36": "淡水龟",
|
| 39 |
+
"37": "箱龟",
|
| 40 |
+
"38": "带状壁虎",
|
| 41 |
+
"39": "普通鬣蜥",
|
| 42 |
+
"40": "美国变色龙",
|
| 43 |
+
"41": "鞭尾蜥蜴",
|
| 44 |
+
"42": "飞龙科蜥蜴",
|
| 45 |
+
"43": "褶边蜥蜴",
|
| 46 |
+
"44": "鳄鱼蜥蜴",
|
| 47 |
+
"45": "毒蜥",
|
| 48 |
+
"46": "绿蜥蜴",
|
| 49 |
+
"47": "非洲变色龙",
|
| 50 |
+
"48": "科莫多蜥蜴",
|
| 51 |
+
"49": "非洲鳄,尼罗河鳄鱼",
|
| 52 |
+
"50": "美国鳄鱼,鳄鱼",
|
| 53 |
+
"51": "三角龙",
|
| 54 |
+
"52": "雷蛇,蠕虫蛇",
|
| 55 |
+
"53": "环蛇,环颈蛇",
|
| 56 |
+
"54": "希腊蛇",
|
| 57 |
+
"55": "绿蛇,草蛇",
|
| 58 |
+
"56": "国王蛇",
|
| 59 |
+
"57": "袜带蛇,草蛇",
|
| 60 |
+
"58": "水蛇",
|
| 61 |
+
"59": "藤蛇",
|
| 62 |
+
"60": "夜蛇",
|
| 63 |
+
"61": "大蟒蛇",
|
| 64 |
+
"62": "岩石蟒蛇,岩蛇,蟒蛇",
|
| 65 |
+
"63": "印度眼镜蛇",
|
| 66 |
+
"64": "绿曼巴",
|
| 67 |
+
"65": "海蛇",
|
| 68 |
+
"66": "角腹蛇",
|
| 69 |
+
"67": "菱纹响尾蛇",
|
| 70 |
+
"68": "角响尾蛇",
|
| 71 |
+
"69": "三叶虫",
|
| 72 |
+
"70": "盲蜘蛛",
|
| 73 |
+
"71": "蝎子",
|
| 74 |
+
"72": "黑金花园蜘蛛",
|
| 75 |
+
"73": "谷仓蜘蛛",
|
| 76 |
+
"74": "花园蜘蛛",
|
| 77 |
+
"75": "黑寡妇蜘蛛",
|
| 78 |
+
"76": "狼蛛",
|
| 79 |
+
"77": "狼蜘蛛,狩猎蜘蛛",
|
| 80 |
+
"78": "壁虱",
|
| 81 |
+
"79": "蜈蚣",
|
| 82 |
+
"80": "黑松鸡",
|
| 83 |
+
"81": "松鸡,雷鸟",
|
| 84 |
+
"82": "披肩鸡,披肩榛鸡",
|
| 85 |
+
"83": "草原鸡,草原松鸡",
|
| 86 |
+
"84": "孔雀",
|
| 87 |
+
"85": "鹌鹑",
|
| 88 |
+
"86": "鹧鸪",
|
| 89 |
+
"87": "非洲灰鹦鹉",
|
| 90 |
+
"88": "金刚鹦鹉",
|
| 91 |
+
"89": "硫冠鹦鹉",
|
| 92 |
+
"90": "短尾鹦鹉",
|
| 93 |
+
"91": "褐翅鸦鹃",
|
| 94 |
+
"92": "蜜蜂",
|
| 95 |
+
"93": "犀鸟",
|
| 96 |
+
"94": "蜂鸟",
|
| 97 |
+
"95": "鹟䴕",
|
| 98 |
+
"96": "犀鸟",
|
| 99 |
+
"97": "野鸭",
|
| 100 |
+
"98": "红胸秋沙鸭",
|
| 101 |
+
"99": "鹅",
|
| 102 |
+
"100": "黑天鹅",
|
| 103 |
+
"101": "大象",
|
| 104 |
+
"102": "针鼹鼠",
|
| 105 |
+
"103": "鸭嘴兽",
|
| 106 |
+
"104": "沙袋鼠",
|
| 107 |
+
"105": "考拉,考拉熊",
|
| 108 |
+
"106": "袋熊",
|
| 109 |
+
"107": "水母",
|
| 110 |
+
"108": "海葵",
|
| 111 |
+
"109": "脑珊瑚",
|
| 112 |
+
"110": "扁形虫扁虫",
|
| 113 |
+
"111": "线虫,蛔虫",
|
| 114 |
+
"112": "海螺",
|
| 115 |
+
"113": "蜗牛",
|
| 116 |
+
"114": "鼻涕虫",
|
| 117 |
+
"115": "海参",
|
| 118 |
+
"116": "石鳖",
|
| 119 |
+
"117": "鹦鹉螺",
|
| 120 |
+
"118": "珍宝蟹",
|
| 121 |
+
"119": "石蟹",
|
| 122 |
+
"120": "招潮蟹",
|
| 123 |
+
"121": "帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹",
|
| 124 |
+
"122": "美国龙虾,缅因州龙虾",
|
| 125 |
+
"123": "大螯虾",
|
| 126 |
+
"124": "小龙虾",
|
| 127 |
+
"125": "寄居蟹",
|
| 128 |
+
"126": "等足目动物(明虾和螃蟹近亲)",
|
| 129 |
+
"127": "白鹳",
|
| 130 |
+
"128": "黑鹳",
|
| 131 |
+
"129": "鹭",
|
| 132 |
+
"130": "火烈鸟",
|
| 133 |
+
"131": "小蓝鹭",
|
| 134 |
+
"132": "美国鹭,大白鹭",
|
| 135 |
+
"133": "麻鸦",
|
| 136 |
+
"134": "鹤",
|
| 137 |
+
"135": "秧鹤",
|
| 138 |
+
"136": "欧洲水鸡,紫水鸡",
|
| 139 |
+
"137": "沼泽泥母鸡,水母鸡",
|
| 140 |
+
"138": "鸨",
|
| 141 |
+
"139": "红翻石鹬",
|
| 142 |
+
"140": "红背鹬,黑腹滨鹬",
|
| 143 |
+
"141": "红脚鹬",
|
| 144 |
+
"142": "半蹼鹬",
|
| 145 |
+
"143": "蛎鹬",
|
| 146 |
+
"144": "鹈鹕",
|
| 147 |
+
"145": "国王企鹅",
|
| 148 |
+
"146": "信天翁,大海鸟",
|
| 149 |
+
"147": "灰鲸",
|
| 150 |
+
"148": "杀人鲸,逆戟鲸,虎鲸",
|
| 151 |
+
"149": "海牛",
|
| 152 |
+
"150": "海狮",
|
| 153 |
+
"151": "奇瓦瓦",
|
| 154 |
+
"152": "日本猎犬",
|
| 155 |
+
"153": "马尔济斯犬",
|
| 156 |
+
"154": "狮子狗",
|
| 157 |
+
"155": "西施犬",
|
| 158 |
+
"156": "布莱尼姆猎犬",
|
| 159 |
+
"157": "巴比狗",
|
| 160 |
+
"158": "玩具犬",
|
| 161 |
+
"159": "罗得西亚长背猎狗",
|
| 162 |
+
"160": "阿富汗猎犬",
|
| 163 |
+
"161": "猎犬",
|
| 164 |
+
"162": "比格犬,猎兔犬",
|
| 165 |
+
"163": "侦探犬",
|
| 166 |
+
"164": "蓝色快狗",
|
| 167 |
+
"165": "黑褐猎浣熊犬",
|
| 168 |
+
"166": "沃克猎犬",
|
| 169 |
+
"167": "英国猎狐犬",
|
| 170 |
+
"168": "美洲赤狗",
|
| 171 |
+
"169": "俄罗斯猎狼犬",
|
| 172 |
+
"170": "爱尔兰猎狼犬",
|
| 173 |
+
"171": "意大利灰狗",
|
| 174 |
+
"172": "惠比特犬",
|
| 175 |
+
"173": "依比沙猎犬",
|
| 176 |
+
"174": "挪威猎犬",
|
| 177 |
+
"175": "奥达猎犬,水獭猎犬",
|
| 178 |
+
"176": "沙克犬,瞪羚猎犬",
|
| 179 |
+
"177": "苏格兰猎鹿犬,猎鹿犬",
|
| 180 |
+
"178": "威玛猎犬",
|
| 181 |
+
"179": "斯塔福德郡牛头梗,斯塔福德郡斗牛梗",
|
| 182 |
+
"180": "美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗",
|
| 183 |
+
"181": "贝德灵顿梗",
|
| 184 |
+
"182": "边境梗",
|
| 185 |
+
"183": "凯丽蓝梗",
|
| 186 |
+
"184": "爱尔兰梗",
|
| 187 |
+
"185": "诺福克梗",
|
| 188 |
+
"186": "诺维奇梗",
|
| 189 |
+
"187": "约克郡梗",
|
| 190 |
+
"188": "刚毛猎狐梗",
|
| 191 |
+
"189": "莱克兰梗",
|
| 192 |
+
"190": "锡利哈姆梗",
|
| 193 |
+
"191": "艾尔谷犬",
|
| 194 |
+
"192": "凯恩梗",
|
| 195 |
+
"193": "澳大利亚梗",
|
| 196 |
+
"194": "丹迪丁蒙梗",
|
| 197 |
+
"195": "波士顿梗",
|
| 198 |
+
"196": "迷你雪纳瑞犬",
|
| 199 |
+
"197": "巨型雪纳瑞犬",
|
| 200 |
+
"198": "标准雪纳瑞犬",
|
| 201 |
+
"199": "苏格兰梗",
|
| 202 |
+
"200": "西藏梗,菊花狗",
|
| 203 |
+
"201": "丝毛梗",
|
| 204 |
+
"202": "软毛麦色梗",
|
| 205 |
+
"203": "西高地白梗",
|
| 206 |
+
"204": "拉萨阿普索犬",
|
| 207 |
+
"205": "平毛寻回犬",
|
| 208 |
+
"206": "卷毛寻回犬",
|
| 209 |
+
"207": "金毛猎犬",
|
| 210 |
+
"208": "拉布拉多猎犬",
|
| 211 |
+
"209": "乞沙比克猎犬",
|
| 212 |
+
"210": "德国短毛猎犬",
|
| 213 |
+
"211": "维兹拉犬",
|
| 214 |
+
"212": "英国谍犬",
|
| 215 |
+
"213": "爱尔兰雪达犬,红色猎犬",
|
| 216 |
+
"214": "戈登雪达犬",
|
| 217 |
+
"215": "布列塔尼犬猎犬",
|
| 218 |
+
"216": "黄毛,黄毛猎犬",
|
| 219 |
+
"217": "英国史宾格犬",
|
| 220 |
+
"218": "威尔士史宾格犬",
|
| 221 |
+
"219": "可卡犬,英国可卡犬",
|
| 222 |
+
"220": "萨塞克斯猎犬",
|
| 223 |
+
"221": "爱尔兰水猎犬",
|
| 224 |
+
"222": "哥威斯犬",
|
| 225 |
+
"223": "舒柏奇犬",
|
| 226 |
+
"224": "比利时牧羊犬",
|
| 227 |
+
"225": "马里努阿犬",
|
| 228 |
+
"226": "伯瑞犬",
|
| 229 |
+
"227": "凯尔皮犬",
|
| 230 |
+
"228": "匈牙利牧羊犬",
|
| 231 |
+
"229": "老英国牧羊犬",
|
| 232 |
+
"230": "喜乐蒂牧羊犬",
|
| 233 |
+
"231": "牧羊犬",
|
| 234 |
+
"232": "边境牧羊犬",
|
| 235 |
+
"233": "法兰德斯牧牛狗",
|
| 236 |
+
"234": "罗特韦尔犬",
|
| 237 |
+
"235": "德国牧羊犬,德国警犬,阿尔萨斯",
|
| 238 |
+
"236": "多伯曼犬,杜宾犬",
|
| 239 |
+
"237": "迷你杜宾犬",
|
| 240 |
+
"238": "大瑞士山地犬",
|
| 241 |
+
"239": "伯恩山犬",
|
| 242 |
+
"240": "Appenzeller狗",
|
| 243 |
+
"241": "EntleBucher狗",
|
| 244 |
+
"242": "拳师狗",
|
| 245 |
+
"243": "斗牛獒",
|
| 246 |
+
"244": "藏獒",
|
| 247 |
+
"245": "法国斗牛犬",
|
| 248 |
+
"246": "大丹犬",
|
| 249 |
+
"247": "圣伯纳德狗",
|
| 250 |
+
"248": "爱斯基摩犬,哈士奇",
|
| 251 |
+
"249": "雪橇犬,阿拉斯加爱斯基摩狗",
|
| 252 |
+
"250": "哈士奇",
|
| 253 |
+
"251": "达尔马提亚,教练车狗",
|
| 254 |
+
"252": "狮毛狗",
|
| 255 |
+
"253": "巴辛吉狗",
|
| 256 |
+
"254": "哈巴狗,狮子狗",
|
| 257 |
+
"255": "莱昂贝格狗",
|
| 258 |
+
"256": "纽芬兰岛狗",
|
| 259 |
+
"257": "大白熊犬",
|
| 260 |
+
"258": "萨摩耶犬",
|
| 261 |
+
"259": "博美犬",
|
| 262 |
+
"260": "松狮,松狮",
|
| 263 |
+
"261": "荷兰卷尾狮毛狗",
|
| 264 |
+
"262": "布鲁塞尔格林芬犬",
|
| 265 |
+
"263": "彭布洛克威尔士科基犬",
|
| 266 |
+
"264": "威尔士柯基犬",
|
| 267 |
+
"265": "玩具贵宾犬",
|
| 268 |
+
"266": "迷你贵宾犬",
|
| 269 |
+
"267": "标准贵宾犬",
|
| 270 |
+
"268": "墨西哥无毛犬",
|
| 271 |
+
"269": "灰狼",
|
| 272 |
+
"270": "白狼,北极狼",
|
| 273 |
+
"271": "红太狼,鬃狼,犬犬鲁弗斯",
|
| 274 |
+
"272": "狼,草原狼,刷狼,郊狼",
|
| 275 |
+
"273": "澳洲野狗,澳大利亚野犬",
|
| 276 |
+
"274": "豺",
|
| 277 |
+
"275": "非洲猎犬,土狼犬",
|
| 278 |
+
"276": "鬣狗",
|
| 279 |
+
"277": "红狐狸",
|
| 280 |
+
"278": "沙狐",
|
| 281 |
+
"279": "北极狐狸,白狐狸",
|
| 282 |
+
"280": "灰狐狸",
|
| 283 |
+
"281": "虎斑猫",
|
| 284 |
+
"282": "山猫,虎猫",
|
| 285 |
+
"283": "波斯猫",
|
| 286 |
+
"284": "暹罗暹罗猫,",
|
| 287 |
+
"285": "埃及猫",
|
| 288 |
+
"286": "美洲狮,美洲豹",
|
| 289 |
+
"287": "猞猁,山猫",
|
| 290 |
+
"288": "豹子",
|
| 291 |
+
"289": "雪豹",
|
| 292 |
+
"290": "美洲虎",
|
| 293 |
+
"291": "狮子",
|
| 294 |
+
"292": "老虎",
|
| 295 |
+
"293": "猎豹",
|
| 296 |
+
"294": "棕熊",
|
| 297 |
+
"295": "美洲黑熊",
|
| 298 |
+
"296": "冰熊,北极熊",
|
| 299 |
+
"297": "懒熊",
|
| 300 |
+
"298": "猫鼬",
|
| 301 |
+
"299": "猫鼬,海猫",
|
| 302 |
+
"300": "虎甲虫",
|
| 303 |
+
"301": "瓢虫",
|
| 304 |
+
"302": "土鳖虫",
|
| 305 |
+
"303": "天牛",
|
| 306 |
+
"304": "龟甲虫",
|
| 307 |
+
"305": "粪甲虫",
|
| 308 |
+
"306": "犀牛甲虫",
|
| 309 |
+
"307": "象甲",
|
| 310 |
+
"308": "苍蝇",
|
| 311 |
+
"309": "蜜蜂",
|
| 312 |
+
"310": "蚂蚁",
|
| 313 |
+
"311": "蚱蜢",
|
| 314 |
+
"312": "蟋蟀",
|
| 315 |
+
"313": "竹节虫",
|
| 316 |
+
"314": "蟑螂",
|
| 317 |
+
"315": "螳螂",
|
| 318 |
+
"316": "蝉",
|
| 319 |
+
"317": "叶蝉",
|
| 320 |
+
"318": "草蜻蛉",
|
| 321 |
+
"319": "蜻蜓",
|
| 322 |
+
"320": "豆娘,蜻蛉",
|
| 323 |
+
"321": "优红蛱蝶",
|
| 324 |
+
"322": "小环蝴蝶",
|
| 325 |
+
"323": "君主蝴蝶,大斑蝶",
|
| 326 |
+
"324": "菜粉蝶",
|
| 327 |
+
"325": "白蝴蝶",
|
| 328 |
+
"326": "灰蝶",
|
| 329 |
+
"327": "海星",
|
| 330 |
+
"328": "海胆",
|
| 331 |
+
"329": "海参,海黄瓜",
|
| 332 |
+
"330": "野兔",
|
| 333 |
+
"331": "兔",
|
| 334 |
+
"332": "安哥拉兔",
|
| 335 |
+
"333": "仓鼠",
|
| 336 |
+
"334": "刺猬,豪猪,",
|
| 337 |
+
"335": "黑松鼠",
|
| 338 |
+
"336": "土拨鼠",
|
| 339 |
+
"337": "海狸",
|
| 340 |
+
"338": "豚鼠,豚鼠",
|
| 341 |
+
"339": "栗色马",
|
| 342 |
+
"340": "斑马",
|
| 343 |
+
"341": "猪",
|
| 344 |
+
"342": "野猪",
|
| 345 |
+
"343": "疣猪",
|
| 346 |
+
"344": "河马",
|
| 347 |
+
"345": "牛",
|
| 348 |
+
"346": "水牛,亚洲水牛",
|
| 349 |
+
"347": "野牛",
|
| 350 |
+
"348": "公羊",
|
| 351 |
+
"349": "大角羊,洛矶山大角羊",
|
| 352 |
+
"350": "山羊",
|
| 353 |
+
"351": "狷羚",
|
| 354 |
+
"352": "黑斑羚",
|
| 355 |
+
"353": "瞪羚",
|
| 356 |
+
"354": "阿拉伯单峰骆驼,骆驼",
|
| 357 |
+
"355": "羊驼",
|
| 358 |
+
"356": "黄鼠狼",
|
| 359 |
+
"357": "水貂",
|
| 360 |
+
"358": "臭猫",
|
| 361 |
+
"359": "黑足鼬",
|
| 362 |
+
"360": "水獭",
|
| 363 |
+
"361": "臭鼬,木猫",
|
| 364 |
+
"362": "獾",
|
| 365 |
+
"363": "犰狳",
|
| 366 |
+
"364": "树懒",
|
| 367 |
+
"365": "猩猩,婆罗洲猩猩",
|
| 368 |
+
"366": "大猩猩",
|
| 369 |
+
"367": "黑猩猩",
|
| 370 |
+
"368": "长臂猿",
|
| 371 |
+
"369": "合趾猿长臂猿,合趾猿",
|
| 372 |
+
"370": "长尾猴",
|
| 373 |
+
"371": "赤猴",
|
| 374 |
+
"372": "狒狒",
|
| 375 |
+
"373": "恒河猴,猕猴",
|
| 376 |
+
"374": "白头叶猴",
|
| 377 |
+
"375": "疣猴",
|
| 378 |
+
"376": "长鼻猴",
|
| 379 |
+
"377": "狨(美洲产小型长尾猴)",
|
| 380 |
+
"378": "卷尾猴",
|
| 381 |
+
"379": "吼猴",
|
| 382 |
+
"380": "伶猴",
|
| 383 |
+
"381": "蜘蛛猴",
|
| 384 |
+
"382": "松鼠猴",
|
| 385 |
+
"383": "马达加斯加环尾狐猴,鼠狐猴",
|
| 386 |
+
"384": "大狐猴,马达加斯加大狐猴",
|
| 387 |
+
"385": "印度大象,亚洲象",
|
| 388 |
+
"386": "非洲象,非洲象",
|
| 389 |
+
"387": "小熊猫",
|
| 390 |
+
"388": "大熊猫",
|
| 391 |
+
"389": "杖鱼",
|
| 392 |
+
"390": "鳗鱼",
|
| 393 |
+
"391": "银鲑,银鲑���",
|
| 394 |
+
"392": "三色刺蝶鱼",
|
| 395 |
+
"393": "海葵鱼",
|
| 396 |
+
"394": "鲟鱼",
|
| 397 |
+
"395": "雀鳝",
|
| 398 |
+
"396": "狮子鱼",
|
| 399 |
+
"397": "河豚",
|
| 400 |
+
"398": "算盘",
|
| 401 |
+
"399": "长袍",
|
| 402 |
+
"400": "学位袍",
|
| 403 |
+
"401": "手风琴",
|
| 404 |
+
"402": "原声吉他",
|
| 405 |
+
"403": "航空母舰",
|
| 406 |
+
"404": "客机",
|
| 407 |
+
"405": "飞艇",
|
| 408 |
+
"406": "祭坛",
|
| 409 |
+
"407": "救护车",
|
| 410 |
+
"408": "水陆两用车",
|
| 411 |
+
"409": "模拟时钟",
|
| 412 |
+
"410": "蜂房",
|
| 413 |
+
"411": "围裙",
|
| 414 |
+
"412": "垃圾桶",
|
| 415 |
+
"413": "攻击步枪,枪",
|
| 416 |
+
"414": "背包",
|
| 417 |
+
"415": "面包店,面包铺,",
|
| 418 |
+
"416": "平衡木",
|
| 419 |
+
"417": "热气球",
|
| 420 |
+
"418": "圆珠笔",
|
| 421 |
+
"419": "创可贴",
|
| 422 |
+
"420": "班卓琴",
|
| 423 |
+
"421": "栏杆,楼梯扶手",
|
| 424 |
+
"422": "杠铃",
|
| 425 |
+
"423": "理发师的椅子",
|
| 426 |
+
"424": "理发店",
|
| 427 |
+
"425": "牲口棚",
|
| 428 |
+
"426": "晴雨表",
|
| 429 |
+
"427": "圆筒",
|
| 430 |
+
"428": "园地小车,手推车",
|
| 431 |
+
"429": "棒球",
|
| 432 |
+
"430": "篮球",
|
| 433 |
+
"431": "婴儿床",
|
| 434 |
+
"432": "巴松管,低音管",
|
| 435 |
+
"433": "游泳帽",
|
| 436 |
+
"434": "沐浴毛巾",
|
| 437 |
+
"435": "浴缸,澡盆",
|
| 438 |
+
"436": "沙滩车,旅行车",
|
| 439 |
+
"437": "灯塔",
|
| 440 |
+
"438": "高脚杯",
|
| 441 |
+
"439": "熊皮高帽",
|
| 442 |
+
"440": "啤酒瓶",
|
| 443 |
+
"441": "啤酒杯",
|
| 444 |
+
"442": "钟塔",
|
| 445 |
+
"443": "(小儿用的)围嘴",
|
| 446 |
+
"444": "串联自行车,",
|
| 447 |
+
"445": "比基尼",
|
| 448 |
+
"446": "装订册",
|
| 449 |
+
"447": "双筒望远镜",
|
| 450 |
+
"448": "鸟舍",
|
| 451 |
+
"449": "船库",
|
| 452 |
+
"450": "雪橇",
|
| 453 |
+
"451": "饰扣式领带",
|
| 454 |
+
"452": "阔边女帽",
|
| 455 |
+
"453": "书橱",
|
| 456 |
+
"454": "书店,书摊",
|
| 457 |
+
"455": "瓶盖",
|
| 458 |
+
"456": "弓箭",
|
| 459 |
+
"457": "蝴蝶结领结",
|
| 460 |
+
"458": "铜制牌位",
|
| 461 |
+
"459": "奶罩",
|
| 462 |
+
"460": "防波堤,海堤",
|
| 463 |
+
"461": "铠甲",
|
| 464 |
+
"462": "扫帚",
|
| 465 |
+
"463": "桶",
|
| 466 |
+
"464": "扣环",
|
| 467 |
+
"465": "防弹背心",
|
| 468 |
+
"466": "动车,子弹头列车",
|
| 469 |
+
"467": "肉铺,肉菜市场",
|
| 470 |
+
"468": "出租车",
|
| 471 |
+
"469": "大锅",
|
| 472 |
+
"470": "蜡烛",
|
| 473 |
+
"471": "大炮",
|
| 474 |
+
"472": "独木舟",
|
| 475 |
+
"473": "开瓶器,开罐器",
|
| 476 |
+
"474": "开衫",
|
| 477 |
+
"475": "车镜",
|
| 478 |
+
"476": "旋转木马",
|
| 479 |
+
"477": "木匠的工具包,工具包",
|
| 480 |
+
"478": "纸箱",
|
| 481 |
+
"479": "车轮",
|
| 482 |
+
"480": "取款机,自动取款机",
|
| 483 |
+
"481": "盒式录音带",
|
| 484 |
+
"482": "卡带播放器",
|
| 485 |
+
"483": "城堡",
|
| 486 |
+
"484": "双体船",
|
| 487 |
+
"485": "CD播放器",
|
| 488 |
+
"486": "大提琴",
|
| 489 |
+
"487": "移动电话,手机",
|
| 490 |
+
"488": "铁链",
|
| 491 |
+
"489": "围栏",
|
| 492 |
+
"490": "链甲",
|
| 493 |
+
"491": "电锯,油锯",
|
| 494 |
+
"492": "箱子",
|
| 495 |
+
"493": "衣柜,洗脸台",
|
| 496 |
+
"494": "编钟,钟,锣",
|
| 497 |
+
"495": "中国橱柜",
|
| 498 |
+
"496": "圣诞袜",
|
| 499 |
+
"497": "教堂,教堂建筑",
|
| 500 |
+
"498": "电影院,剧场",
|
| 501 |
+
"499": "切肉刀,菜刀",
|
| 502 |
+
"500": "悬崖屋",
|
| 503 |
+
"501": "斗篷",
|
| 504 |
+
"502": "木屐,木鞋",
|
| 505 |
+
"503": "鸡尾酒调酒器",
|
| 506 |
+
"504": "咖啡杯",
|
| 507 |
+
"505": "咖啡壶",
|
| 508 |
+
"506": "螺旋结构(楼梯)",
|
| 509 |
+
"507": "组合锁",
|
| 510 |
+
"508": "电脑键盘,键盘",
|
| 511 |
+
"509": "糖果,糖果店",
|
| 512 |
+
"510": "集装箱船",
|
| 513 |
+
"511": "敞篷车",
|
| 514 |
+
"512": "开瓶器,瓶螺杆",
|
| 515 |
+
"513": "短号,喇叭",
|
| 516 |
+
"514": "牛仔靴",
|
| 517 |
+
"515": "牛仔帽",
|
| 518 |
+
"516": "摇篮",
|
| 519 |
+
"517": "起重机",
|
| 520 |
+
"518": "头盔",
|
| 521 |
+
"519": "板条箱",
|
| 522 |
+
"520": "小儿床",
|
| 523 |
+
"521": "砂锅",
|
| 524 |
+
"522": "槌球",
|
| 525 |
+
"523": "拐杖",
|
| 526 |
+
"524": "胸甲",
|
| 527 |
+
"525": "大坝,堤防",
|
| 528 |
+
"526": "书桌",
|
| 529 |
+
"527": "台式电脑",
|
| 530 |
+
"528": "有线电话",
|
| 531 |
+
"529": "尿布湿",
|
| 532 |
+
"530": "数字时钟",
|
| 533 |
+
"531": "数字手表",
|
| 534 |
+
"532": "餐桌板",
|
| 535 |
+
"533": "抹布",
|
| 536 |
+
"534": "洗碗机,洗碟机",
|
| 537 |
+
"535": "盘式制动器",
|
| 538 |
+
"536": "码头,船坞,码头设施",
|
| 539 |
+
"537": "狗拉雪橇",
|
| 540 |
+
"538": "圆顶",
|
| 541 |
+
"539": "门垫,垫子",
|
| 542 |
+
"540": "钻井平台,海上钻井",
|
| 543 |
+
"541": "鼓,乐器,鼓膜",
|
| 544 |
+
"542": "鼓槌",
|
| 545 |
+
"543": "哑铃",
|
| 546 |
+
"544": "荷兰烤箱",
|
| 547 |
+
"545": "电风扇,鼓风机",
|
| 548 |
+
"546": "电吉他",
|
| 549 |
+
"547": "电力机车",
|
| 550 |
+
"548": "电视,电视柜",
|
| 551 |
+
"549": "信封",
|
| 552 |
+
"550": "浓缩咖啡机",
|
| 553 |
+
"551": "扑面粉",
|
| 554 |
+
"552": "女用长围巾",
|
| 555 |
+
"553": "文件,文件柜,档案柜",
|
| 556 |
+
"554": "消防船",
|
| 557 |
+
"555": "消防车",
|
| 558 |
+
"556": "火炉栏",
|
| 559 |
+
"557": "旗杆",
|
| 560 |
+
"558": "长笛",
|
| 561 |
+
"559": "折叠椅",
|
| 562 |
+
"560": "橄榄球头盔",
|
| 563 |
+
"561": "叉车",
|
| 564 |
+
"562": "喷泉",
|
| 565 |
+
"563": "钢笔",
|
| 566 |
+
"564": "有四根帷柱的床",
|
| 567 |
+
"565": "运货车厢",
|
| 568 |
+
"566": "圆号,喇叭",
|
| 569 |
+
"567": "煎锅",
|
| 570 |
+
"568": "裘皮大衣",
|
| 571 |
+
"569": "垃圾车",
|
| 572 |
+
"570": "防毒面具,呼吸器",
|
| 573 |
+
"571": "汽油泵",
|
| 574 |
+
"572": "高脚杯",
|
| 575 |
+
"573": "卡丁车",
|
| 576 |
+
"574": "高尔夫球",
|
| 577 |
+
"575": "高尔夫球车",
|
| 578 |
+
"576": "狭长小船",
|
| 579 |
+
"577": "锣",
|
| 580 |
+
"578": "礼服",
|
| 581 |
+
"579": "钢琴",
|
| 582 |
+
"580": "温室,苗圃",
|
| 583 |
+
"581": "散热器格栅",
|
| 584 |
+
"582": "杂货店,食品市场",
|
| 585 |
+
"583": "断头台",
|
| 586 |
+
"584": "小发夹",
|
| 587 |
+
"585": "头发喷雾",
|
| 588 |
+
"586": "半履带装甲车",
|
| 589 |
+
"587": "锤子",
|
| 590 |
+
"588": "大篮子",
|
| 591 |
+
"589": "手摇鼓风机,吹风机",
|
| 592 |
+
"590": "手提电脑",
|
| 593 |
+
"591": "手帕",
|
| 594 |
+
"592": "硬盘",
|
| 595 |
+
"593": "口琴,口风琴",
|
| 596 |
+
"594": "竖琴",
|
| 597 |
+
"595": "收割机",
|
| 598 |
+
"596": "斧头",
|
| 599 |
+
"597": "手枪皮套",
|
| 600 |
+
"598": "家庭影院",
|
| 601 |
+
"599": "蜂窝",
|
| 602 |
+
"600": "钩爪",
|
| 603 |
+
"601": "衬裙",
|
| 604 |
+
"602": "单杠",
|
| 605 |
+
"603": "马车",
|
| 606 |
+
"604": "沙漏",
|
| 607 |
+
"605": "手机,iPad",
|
| 608 |
+
"606": "熨斗",
|
| 609 |
+
"607": "南瓜灯笼",
|
| 610 |
+
"608": "牛仔裤,蓝色牛仔裤",
|
| 611 |
+
"609": "吉普车",
|
| 612 |
+
"610": "运动衫,T恤",
|
| 613 |
+
"611": "拼图",
|
| 614 |
+
"612": "人力车",
|
| 615 |
+
"613": "操纵杆",
|
| 616 |
+
"614": "和服",
|
| 617 |
+
"615": "护膝",
|
| 618 |
+
"616": "蝴蝶结",
|
| 619 |
+
"617": "大褂,实验室外套",
|
| 620 |
+
"618": "长柄勺",
|
| 621 |
+
"619": "灯罩",
|
| 622 |
+
"620": "笔记本电脑",
|
| 623 |
+
"621": "割草机",
|
| 624 |
+
"622": "镜头盖",
|
| 625 |
+
"623": "开信刀,裁纸刀",
|
| 626 |
+
"624": "图书馆",
|
| 627 |
+
"625": "救生艇",
|
| 628 |
+
"626": "点火器,打火机",
|
| 629 |
+
"627": "豪华轿车",
|
| 630 |
+
"628": "远洋班轮",
|
| 631 |
+
"629": "唇膏,口红",
|
| 632 |
+
"630": "平底便鞋",
|
| 633 |
+
"631": "洗剂",
|
| 634 |
+
"632": "扬声器",
|
| 635 |
+
"633": "放大镜",
|
| 636 |
+
"634": "锯木厂",
|
| 637 |
+
"635": "磁罗盘",
|
| 638 |
+
"636": "邮袋",
|
| 639 |
+
"637": "信箱",
|
| 640 |
+
"638": "女游泳衣",
|
| 641 |
+
"639": "有肩带浴衣",
|
| 642 |
+
"640": "窨井盖",
|
| 643 |
+
"641": "沙球(一种打击乐器)",
|
| 644 |
+
"642": "马林巴木琴",
|
| 645 |
+
"643": "面膜",
|
| 646 |
+
"644": "火柴",
|
| 647 |
+
"645": "花柱",
|
| 648 |
+
"646": "迷宫",
|
| 649 |
+
"647": "量杯",
|
| 650 |
+
"648": "药箱",
|
| 651 |
+
"649": "巨石,巨石结构",
|
| 652 |
+
"650": "麦克风",
|
| 653 |
+
"651": "微波炉",
|
| 654 |
+
"652": "军装",
|
| 655 |
+
"653": "奶桶",
|
| 656 |
+
"654": "迷你巴士",
|
| 657 |
+
"655": "迷你裙",
|
| 658 |
+
"656": "面包车",
|
| 659 |
+
"657": "导弹",
|
| 660 |
+
"658": "连指手套",
|
| 661 |
+
"659": "搅拌钵",
|
| 662 |
+
"660": "活动房屋(由汽车拖拉的)",
|
| 663 |
+
"661": "T型发动机小汽车",
|
| 664 |
+
"662": "调制解调器",
|
| 665 |
+
"663": "修道院",
|
| 666 |
+
"664": "显示器",
|
| 667 |
+
"665": "电瓶车",
|
| 668 |
+
"666": "砂浆",
|
| 669 |
+
"667": "学士",
|
| 670 |
+
"668": "清真寺",
|
| 671 |
+
"669": "蚊帐",
|
| 672 |
+
"670": "摩托车",
|
| 673 |
+
"671": "山地自行车",
|
| 674 |
+
"672": "登山帐",
|
| 675 |
+
"673": "鼠标,电脑鼠标",
|
| 676 |
+
"674": "捕鼠器",
|
| 677 |
+
"675": "搬家车",
|
| 678 |
+
"676": "口套",
|
| 679 |
+
"677": "钉子",
|
| 680 |
+
"678": "颈托",
|
| 681 |
+
"679": "项链",
|
| 682 |
+
"680": "乳头(瓶)",
|
| 683 |
+
"681": "笔记本,笔记本电脑",
|
| 684 |
+
"682": "方尖碑",
|
| 685 |
+
"683": "双簧管",
|
| 686 |
+
"684": "陶笛,卵形笛",
|
| 687 |
+
"685": "里程表",
|
| 688 |
+
"686": "滤油器",
|
| 689 |
+
"687": "风琴,管风琴",
|
| 690 |
+
"688": "示波器",
|
| 691 |
+
"689": "罩裙",
|
| 692 |
+
"690": "牛车",
|
| 693 |
+
"691": "氧气面罩",
|
| 694 |
+
"692": "包装",
|
| 695 |
+
"693": "船桨",
|
| 696 |
+
"694": "明轮,桨轮",
|
| 697 |
+
"695": "挂锁,扣锁",
|
| 698 |
+
"696": "画笔",
|
| 699 |
+
"697": "睡衣",
|
| 700 |
+
"698": "宫殿",
|
| 701 |
+
"699": "排箫,鸣管",
|
| 702 |
+
"700": "纸巾",
|
| 703 |
+
"701": "降落伞",
|
| 704 |
+
"702": "双杠",
|
| 705 |
+
"703": "公园长椅",
|
| 706 |
+
"704": "停车收费表,停车计时器",
|
| 707 |
+
"705": "客车,教练车",
|
| 708 |
+
"706": "露台,阳台",
|
| 709 |
+
"707": "付费电话",
|
| 710 |
+
"708": "基座,基脚",
|
| 711 |
+
"709": "铅笔盒",
|
| 712 |
+
"710": "卷笔刀",
|
| 713 |
+
"711": "香水(瓶)",
|
| 714 |
+
"712": "培养皿",
|
| 715 |
+
"713": "复印机",
|
| 716 |
+
"714": "拨弦片,拨子",
|
| 717 |
+
"715": "尖顶头盔",
|
| 718 |
+
"716": "栅栏,栅栏",
|
| 719 |
+
"717": "皮卡,皮卡车",
|
| 720 |
+
"718": "桥墩",
|
| 721 |
+
"719": "存钱罐",
|
| 722 |
+
"720": "药瓶",
|
| 723 |
+
"721": "枕头",
|
| 724 |
+
"722": "乒乓球",
|
| 725 |
+
"723": "风车",
|
| 726 |
+
"724": "海盗船",
|
| 727 |
+
"725": "水罐",
|
| 728 |
+
"726": "木工刨",
|
| 729 |
+
"727": "天文馆",
|
| 730 |
+
"728": "塑料袋",
|
| 731 |
+
"729": "板架",
|
| 732 |
+
"730": "犁型铲雪机",
|
| 733 |
+
"731": "手压皮碗泵",
|
| 734 |
+
"732": "宝丽来相机",
|
| 735 |
+
"733": "电线杆",
|
| 736 |
+
"734": "警车,巡逻车",
|
| 737 |
+
"735": "雨披",
|
| 738 |
+
"736": "台球桌",
|
| 739 |
+
"737": "充气饮料瓶",
|
| 740 |
+
"738": "花盆",
|
| 741 |
+
"739": "陶工旋盘",
|
| 742 |
+
"740": "电钻",
|
| 743 |
+
"741": "祈祷垫,地毯",
|
| 744 |
+
"742": "打印机",
|
| 745 |
+
"743": "监狱",
|
| 746 |
+
"744": "炮弹,导弹",
|
| 747 |
+
"745": "投影仪",
|
| 748 |
+
"746": "冰球",
|
| 749 |
+
"747": "沙包,吊球",
|
| 750 |
+
"748": "钱包",
|
| 751 |
+
"749": "羽管笔",
|
| 752 |
+
"750": "被子",
|
| 753 |
+
"751": "赛车",
|
| 754 |
+
"752": "球拍",
|
| 755 |
+
"753": "散热器",
|
| 756 |
+
"754": "收音机",
|
| 757 |
+
"755": "射电望远镜,无线电反射器",
|
| 758 |
+
"756": "雨桶",
|
| 759 |
+
"757": "休闲车,房车",
|
| 760 |
+
"758": "卷轴,卷筒",
|
| 761 |
+
"759": "反射式照相机",
|
| 762 |
+
"760": "冰箱,冰柜",
|
| 763 |
+
"761": "遥控器",
|
| 764 |
+
"762": "餐厅,饮食店,食堂",
|
| 765 |
+
"763": "左轮手枪",
|
| 766 |
+
"764": "步枪",
|
| 767 |
+
"765": "摇椅",
|
| 768 |
+
"766": "电转烤肉架",
|
| 769 |
+
"767": "橡皮",
|
| 770 |
+
"768": "橄榄球",
|
| 771 |
+
"769": "直尺",
|
| 772 |
+
"770": "跑步鞋",
|
| 773 |
+
"771": "保险柜",
|
| 774 |
+
"772": "安全别针",
|
| 775 |
+
"773": "盐瓶(调味用)",
|
| 776 |
+
"774": "凉鞋",
|
| 777 |
+
"775": "纱笼,围裙",
|
| 778 |
+
"776": "萨克斯管",
|
| 779 |
+
"777": "剑鞘",
|
| 780 |
+
"778": "秤,称重机",
|
| 781 |
+
"779": "校车",
|
| 782 |
+
"780": "帆船",
|
| 783 |
+
"781": "记分牌",
|
| 784 |
+
"782": "屏幕",
|
| 785 |
+
"783": "螺丝",
|
| 786 |
+
"784": "螺丝刀",
|
| 787 |
+
"785": "安全带",
|
| 788 |
+
"786": "缝纫机",
|
| 789 |
+
"787": "盾牌,盾牌",
|
| 790 |
+
"788": "皮鞋店,鞋店",
|
| 791 |
+
"789": "障子",
|
| 792 |
+
"790": "购物篮",
|
| 793 |
+
"791": "购物车",
|
| 794 |
+
"792": "铁锹",
|
| 795 |
+
"793": "浴帽",
|
| 796 |
+
"794": "浴帘",
|
| 797 |
+
"795": "滑雪板",
|
| 798 |
+
"796": "滑雪面罩",
|
| 799 |
+
"797": "睡袋",
|
| 800 |
+
"798": "滑尺",
|
| 801 |
+
"799": "滑动门",
|
| 802 |
+
"800": "角子老虎机",
|
| 803 |
+
"801": "潜水通气管",
|
| 804 |
+
"802": "雪橇",
|
| 805 |
+
"803": "扫雪机,扫雪机",
|
| 806 |
+
"804": "皂液器",
|
| 807 |
+
"805": "足球",
|
| 808 |
+
"806": "袜子",
|
| 809 |
+
"807": "碟式太阳能,太阳能集热器,太阳能炉",
|
| 810 |
+
"808": "宽边帽",
|
| 811 |
+
"809": "汤碗",
|
| 812 |
+
"810": "空格键",
|
| 813 |
+
"811": "空间加热器",
|
| 814 |
+
"812": "航天飞机",
|
| 815 |
+
"813": "铲(搅拌或涂敷用的)",
|
| 816 |
+
"814": "快艇",
|
| 817 |
+
"815": "蜘蛛网",
|
| 818 |
+
"816": "纺锤,纱锭",
|
| 819 |
+
"817": "跑车",
|
| 820 |
+
"818": "聚光灯",
|
| 821 |
+
"819": "舞台",
|
| 822 |
+
"820": "蒸汽机车",
|
| 823 |
+
"821": "钢拱桥",
|
| 824 |
+
"822": "钢滚筒",
|
| 825 |
+
"823": "听诊器",
|
| 826 |
+
"824": "女用披肩",
|
| 827 |
+
"825": "石头墙",
|
| 828 |
+
"826": "秒表",
|
| 829 |
+
"827": "火炉",
|
| 830 |
+
"828": "过滤器",
|
| 831 |
+
"829": "有轨电车,电车",
|
| 832 |
+
"830": "担架",
|
| 833 |
+
"831": "沙发床",
|
| 834 |
+
"832": "佛塔",
|
| 835 |
+
"833": "潜艇,潜水艇",
|
| 836 |
+
"834": "套装,衣服",
|
| 837 |
+
"835": "日晷",
|
| 838 |
+
"836": "太阳镜",
|
| 839 |
+
"837": "太阳镜,墨镜",
|
| 840 |
+
"838": "防晒霜,防晒剂",
|
| 841 |
+
"839": "悬索桥",
|
| 842 |
+
"840": "拖把",
|
| 843 |
+
"841": "运动衫",
|
| 844 |
+
"842": "游泳裤",
|
| 845 |
+
"843": "秋千",
|
| 846 |
+
"844": "开关,电器开关",
|
| 847 |
+
"845": "注射器",
|
| 848 |
+
"846": "台灯",
|
| 849 |
+
"847": "坦克,装甲战车,装甲战斗车辆",
|
| 850 |
+
"848": "磁带播放器",
|
| 851 |
+
"849": "茶壶",
|
| 852 |
+
"850": "泰迪,泰迪熊",
|
| 853 |
+
"851": "电视",
|
| 854 |
+
"852": "网球",
|
| 855 |
+
"853": "茅草,茅草屋顶",
|
| 856 |
+
"854": "幕布,剧院的帷幕",
|
| 857 |
+
"855": "顶针",
|
| 858 |
+
"856": "脱粒机",
|
| 859 |
+
"857": "宝座",
|
| 860 |
+
"858": "瓦屋顶",
|
| 861 |
+
"859": "烤面包机",
|
| 862 |
+
"860": "烟草店,烟草",
|
| 863 |
+
"861": "马桶",
|
| 864 |
+
"862": "火炬",
|
| 865 |
+
"863": "图腾柱",
|
| 866 |
+
"864": "拖车,牵引车,清障车",
|
| 867 |
+
"865": "玩具店",
|
| 868 |
+
"866": "拖拉机",
|
| 869 |
+
"867": "拖车,铰接式卡车",
|
| 870 |
+
"868": "托盘",
|
| 871 |
+
"869": "风衣",
|
| 872 |
+
"870": "三轮车",
|
| 873 |
+
"871": "三体船",
|
| 874 |
+
"872": "三脚架",
|
| 875 |
+
"873": "凯旋门",
|
| 876 |
+
"874": "无轨电车",
|
| 877 |
+
"875": "长号",
|
| 878 |
+
"876": "浴盆,浴缸",
|
| 879 |
+
"877": "旋转式栅门",
|
| 880 |
+
"878": "打字机键盘",
|
| 881 |
+
"879": "伞",
|
| 882 |
+
"880": "独轮车",
|
| 883 |
+
"881": "直立式钢琴",
|
| 884 |
+
"882": "真空吸尘器",
|
| 885 |
+
"883": "花瓶",
|
| 886 |
+
"884": "拱顶",
|
| 887 |
+
"885": "天鹅绒",
|
| 888 |
+
"886": "自动售货机",
|
| 889 |
+
"887": "祭服",
|
| 890 |
+
"888": "高架桥",
|
| 891 |
+
"889": "小提琴,小提琴",
|
| 892 |
+
"890": "排球",
|
| 893 |
+
"891": "松饼机",
|
| 894 |
+
"892": "挂钟",
|
| 895 |
+
"893": "钱包,皮夹",
|
| 896 |
+
"894": "衣柜,壁橱",
|
| 897 |
+
"895": "军用飞机",
|
| 898 |
+
"896": "洗脸盆,洗手盆",
|
| 899 |
+
"897": "洗衣机,自动洗衣机",
|
| 900 |
+
"898": "水瓶",
|
| 901 |
+
"899": "水壶",
|
| 902 |
+
"900": "水塔",
|
| 903 |
+
"901": "威士忌壶",
|
| 904 |
+
"902": "哨子",
|
| 905 |
+
"903": "假发",
|
| 906 |
+
"904": "纱窗",
|
| 907 |
+
"905": "百叶窗",
|
| 908 |
+
"906": "温莎领带",
|
| 909 |
+
"907": "葡萄酒瓶",
|
| 910 |
+
"908": "飞机翅膀,飞机",
|
| 911 |
+
"909": "炒菜锅",
|
| 912 |
+
"910": "木制的勺子",
|
| 913 |
+
"911": "毛织品,羊绒",
|
| 914 |
+
"912": "栅栏,围栏",
|
| 915 |
+
"913": "沉船",
|
| 916 |
+
"914": "双桅船",
|
| 917 |
+
"915": "蒙古包",
|
| 918 |
+
"916": "网站,互联网网站",
|
| 919 |
+
"917": "漫画",
|
| 920 |
+
"918": "纵横字谜",
|
| 921 |
+
"919": "路标",
|
| 922 |
+
"920": "交通信号灯",
|
| 923 |
+
"921": "防尘罩,书皮",
|
| 924 |
+
"922": "菜单",
|
| 925 |
+
"923": "盘子",
|
| 926 |
+
"924": "鳄梨酱",
|
| 927 |
+
"925": "清汤",
|
| 928 |
+
"926": "罐焖土豆烧肉",
|
| 929 |
+
"927": "蛋糕",
|
| 930 |
+
"928": "冰淇淋",
|
| 931 |
+
"929": "雪糕,冰棍,冰棒",
|
| 932 |
+
"930": "法式面包",
|
| 933 |
+
"931": "百吉饼",
|
| 934 |
+
"932": "椒盐脆饼",
|
| 935 |
+
"933": "芝士汉堡",
|
| 936 |
+
"934": "热狗",
|
| 937 |
+
"935": "土豆泥",
|
| 938 |
+
"936": "结球甘蓝",
|
| 939 |
+
"937": "西兰花",
|
| 940 |
+
"938": "菜花",
|
| 941 |
+
"939": "绿皮密生西葫芦",
|
| 942 |
+
"940": "西葫芦",
|
| 943 |
+
"941": "小青南瓜",
|
| 944 |
+
"942": "南瓜",
|
| 945 |
+
"943": "黄瓜",
|
| 946 |
+
"944": "朝鲜蓟",
|
| 947 |
+
"945": "甜椒",
|
| 948 |
+
"946": "刺棘蓟",
|
| 949 |
+
"947": "蘑菇",
|
| 950 |
+
"948": "绿苹果",
|
| 951 |
+
"949": "草莓",
|
| 952 |
+
"950": "橘子",
|
| 953 |
+
"951": "柠檬",
|
| 954 |
+
"952": "无花果",
|
| 955 |
+
"953": "菠萝",
|
| 956 |
+
"954": "香蕉",
|
| 957 |
+
"955": "菠萝蜜",
|
| 958 |
+
"956": "蛋奶冻苹果",
|
| 959 |
+
"957": "石榴",
|
| 960 |
+
"958": "干草",
|
| 961 |
+
"959": "烤面条加干酪沙司",
|
| 962 |
+
"960": "巧克力酱,巧克力糖浆",
|
| 963 |
+
"961": "面团",
|
| 964 |
+
"962": "瑞士肉包,肉饼",
|
| 965 |
+
"963": "披萨,披萨饼",
|
| 966 |
+
"964": "馅饼",
|
| 967 |
+
"965": "卷饼",
|
| 968 |
+
"966": "红葡萄酒",
|
| 969 |
+
"967": "意大利浓咖啡",
|
| 970 |
+
"968": "杯子",
|
| 971 |
+
"969": "蛋酒",
|
| 972 |
+
"970": "高山",
|
| 973 |
+
"971": "泡泡",
|
| 974 |
+
"972": "悬崖",
|
| 975 |
+
"973": "珊瑚礁",
|
| 976 |
+
"974": "间歇泉",
|
| 977 |
+
"975": "湖边,湖岸",
|
| 978 |
+
"976": "海角",
|
| 979 |
+
"977": "沙洲,沙坝",
|
| 980 |
+
"978": "海滨,海岸",
|
| 981 |
+
"979": "峡谷",
|
| 982 |
+
"980": "火山",
|
| 983 |
+
"981": "棒球,棒球运动员",
|
| 984 |
+
"982": "新郎",
|
| 985 |
+
"983": "潜水员",
|
| 986 |
+
"984": "油菜",
|
| 987 |
+
"985": "雏菊",
|
| 988 |
+
"986": "杓兰",
|
| 989 |
+
"987": "玉米",
|
| 990 |
+
"988": "橡子",
|
| 991 |
+
"989": "玫瑰果",
|
| 992 |
+
"990": "七叶树果实",
|
| 993 |
+
"991": "珊瑚菌",
|
| 994 |
+
"992": "木耳",
|
| 995 |
+
"993": "鹿花菌",
|
| 996 |
+
"994": "鬼笔菌",
|
| 997 |
+
"995": "地星(菌类)",
|
| 998 |
+
"996": "多叶奇果菌",
|
| 999 |
+
"997": "牛肝菌",
|
| 1000 |
+
"998": "玉米穗",
|
| 1001 |
+
"999": "卫生纸"
|
| 1002 |
+
}
|
labels/id2label_en.json
ADDED
|
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"0": "tench, Tinca tinca",
|
| 3 |
+
"1": "goldfish, Carassius auratus",
|
| 4 |
+
"2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
|
| 5 |
+
"3": "tiger shark, Galeocerdo cuvieri",
|
| 6 |
+
"4": "hammerhead, hammerhead shark",
|
| 7 |
+
"5": "electric ray, crampfish, numbfish, torpedo",
|
| 8 |
+
"6": "stingray",
|
| 9 |
+
"7": "cock",
|
| 10 |
+
"8": "hen",
|
| 11 |
+
"9": "ostrich, Struthio camelus",
|
| 12 |
+
"10": "brambling, Fringilla montifringilla",
|
| 13 |
+
"11": "goldfinch, Carduelis carduelis",
|
| 14 |
+
"12": "house finch, linnet, Carpodacus mexicanus",
|
| 15 |
+
"13": "junco, snowbird",
|
| 16 |
+
"14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
|
| 17 |
+
"15": "robin, American robin, Turdus migratorius",
|
| 18 |
+
"16": "bulbul",
|
| 19 |
+
"17": "jay",
|
| 20 |
+
"18": "magpie",
|
| 21 |
+
"19": "chickadee",
|
| 22 |
+
"20": "water ouzel, dipper",
|
| 23 |
+
"21": "kite",
|
| 24 |
+
"22": "bald eagle, American eagle, Haliaeetus leucocephalus",
|
| 25 |
+
"23": "vulture",
|
| 26 |
+
"24": "great grey owl, great gray owl, Strix nebulosa",
|
| 27 |
+
"25": "European fire salamander, Salamandra salamandra",
|
| 28 |
+
"26": "common newt, Triturus vulgaris",
|
| 29 |
+
"27": "eft",
|
| 30 |
+
"28": "spotted salamander, Ambystoma maculatum",
|
| 31 |
+
"29": "axolotl, mud puppy, Ambystoma mexicanum",
|
| 32 |
+
"30": "bullfrog, Rana catesbeiana",
|
| 33 |
+
"31": "tree frog, tree-frog",
|
| 34 |
+
"32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
|
| 35 |
+
"33": "loggerhead, loggerhead turtle, Caretta caretta",
|
| 36 |
+
"34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
|
| 37 |
+
"35": "mud turtle",
|
| 38 |
+
"36": "terrapin",
|
| 39 |
+
"37": "box turtle, box tortoise",
|
| 40 |
+
"38": "banded gecko",
|
| 41 |
+
"39": "common iguana, iguana, Iguana iguana",
|
| 42 |
+
"40": "American chameleon, anole, Anolis carolinensis",
|
| 43 |
+
"41": "whiptail, whiptail lizard",
|
| 44 |
+
"42": "agama",
|
| 45 |
+
"43": "frilled lizard, Chlamydosaurus kingi",
|
| 46 |
+
"44": "alligator lizard",
|
| 47 |
+
"45": "Gila monster, Heloderma suspectum",
|
| 48 |
+
"46": "green lizard, Lacerta viridis",
|
| 49 |
+
"47": "African chameleon, Chamaeleo chamaeleon",
|
| 50 |
+
"48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
|
| 51 |
+
"49": "African crocodile, Nile crocodile, Crocodylus niloticus",
|
| 52 |
+
"50": "American alligator, Alligator mississipiensis",
|
| 53 |
+
"51": "triceratops",
|
| 54 |
+
"52": "thunder snake, worm snake, Carphophis amoenus",
|
| 55 |
+
"53": "ringneck snake, ring-necked snake, ring snake",
|
| 56 |
+
"54": "hognose snake, puff adder, sand viper",
|
| 57 |
+
"55": "green snake, grass snake",
|
| 58 |
+
"56": "king snake, kingsnake",
|
| 59 |
+
"57": "garter snake, grass snake",
|
| 60 |
+
"58": "water snake",
|
| 61 |
+
"59": "vine snake",
|
| 62 |
+
"60": "night snake, Hypsiglena torquata",
|
| 63 |
+
"61": "boa constrictor, Constrictor constrictor",
|
| 64 |
+
"62": "rock python, rock snake, Python sebae",
|
| 65 |
+
"63": "Indian cobra, Naja naja",
|
| 66 |
+
"64": "green mamba",
|
| 67 |
+
"65": "sea snake",
|
| 68 |
+
"66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
|
| 69 |
+
"67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
|
| 70 |
+
"68": "sidewinder, horned rattlesnake, Crotalus cerastes",
|
| 71 |
+
"69": "trilobite",
|
| 72 |
+
"70": "harvestman, daddy longlegs, Phalangium opilio",
|
| 73 |
+
"71": "scorpion",
|
| 74 |
+
"72": "black and gold garden spider, Argiope aurantia",
|
| 75 |
+
"73": "barn spider, Araneus cavaticus",
|
| 76 |
+
"74": "garden spider, Aranea diademata",
|
| 77 |
+
"75": "black widow, Latrodectus mactans",
|
| 78 |
+
"76": "tarantula",
|
| 79 |
+
"77": "wolf spider, hunting spider",
|
| 80 |
+
"78": "tick",
|
| 81 |
+
"79": "centipede",
|
| 82 |
+
"80": "black grouse",
|
| 83 |
+
"81": "ptarmigan",
|
| 84 |
+
"82": "ruffed grouse, partridge, Bonasa umbellus",
|
| 85 |
+
"83": "prairie chicken, prairie grouse, prairie fowl",
|
| 86 |
+
"84": "peacock",
|
| 87 |
+
"85": "quail",
|
| 88 |
+
"86": "partridge",
|
| 89 |
+
"87": "African grey, African gray, Psittacus erithacus",
|
| 90 |
+
"88": "macaw",
|
| 91 |
+
"89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
| 92 |
+
"90": "lorikeet",
|
| 93 |
+
"91": "coucal",
|
| 94 |
+
"92": "bee eater",
|
| 95 |
+
"93": "hornbill",
|
| 96 |
+
"94": "hummingbird",
|
| 97 |
+
"95": "jacamar",
|
| 98 |
+
"96": "toucan",
|
| 99 |
+
"97": "drake",
|
| 100 |
+
"98": "red-breasted merganser, Mergus serrator",
|
| 101 |
+
"99": "goose",
|
| 102 |
+
"100": "black swan, Cygnus atratus",
|
| 103 |
+
"101": "tusker",
|
| 104 |
+
"102": "echidna, spiny anteater, anteater",
|
| 105 |
+
"103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
|
| 106 |
+
"104": "wallaby, brush kangaroo",
|
| 107 |
+
"105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
|
| 108 |
+
"106": "wombat",
|
| 109 |
+
"107": "jellyfish",
|
| 110 |
+
"108": "sea anemone, anemone",
|
| 111 |
+
"109": "brain coral",
|
| 112 |
+
"110": "flatworm, platyhelminth",
|
| 113 |
+
"111": "nematode, nematode worm, roundworm",
|
| 114 |
+
"112": "conch",
|
| 115 |
+
"113": "snail",
|
| 116 |
+
"114": "slug",
|
| 117 |
+
"115": "sea slug, nudibranch",
|
| 118 |
+
"116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
|
| 119 |
+
"117": "chambered nautilus, pearly nautilus, nautilus",
|
| 120 |
+
"118": "Dungeness crab, Cancer magister",
|
| 121 |
+
"119": "rock crab, Cancer irroratus",
|
| 122 |
+
"120": "fiddler crab",
|
| 123 |
+
"121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
|
| 124 |
+
"122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
|
| 125 |
+
"123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
|
| 126 |
+
"124": "crayfish, crawfish, crawdad, crawdaddy",
|
| 127 |
+
"125": "hermit crab",
|
| 128 |
+
"126": "isopod",
|
| 129 |
+
"127": "white stork, Ciconia ciconia",
|
| 130 |
+
"128": "black stork, Ciconia nigra",
|
| 131 |
+
"129": "spoonbill",
|
| 132 |
+
"130": "flamingo",
|
| 133 |
+
"131": "little blue heron, Egretta caerulea",
|
| 134 |
+
"132": "American egret, great white heron, Egretta albus",
|
| 135 |
+
"133": "bittern",
|
| 136 |
+
"134": "crane",
|
| 137 |
+
"135": "limpkin, Aramus pictus",
|
| 138 |
+
"136": "European gallinule, Porphyrio porphyrio",
|
| 139 |
+
"137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
|
| 140 |
+
"138": "bustard",
|
| 141 |
+
"139": "ruddy turnstone, Arenaria interpres",
|
| 142 |
+
"140": "red-backed sandpiper, dunlin, Erolia alpina",
|
| 143 |
+
"141": "redshank, Tringa totanus",
|
| 144 |
+
"142": "dowitcher",
|
| 145 |
+
"143": "oystercatcher, oyster catcher",
|
| 146 |
+
"144": "pelican",
|
| 147 |
+
"145": "king penguin, Aptenodytes patagonica",
|
| 148 |
+
"146": "albatross, mollymawk",
|
| 149 |
+
"147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
|
| 150 |
+
"148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
|
| 151 |
+
"149": "dugong, Dugong dugon",
|
| 152 |
+
"150": "sea lion",
|
| 153 |
+
"151": "Chihuahua",
|
| 154 |
+
"152": "Japanese spaniel",
|
| 155 |
+
"153": "Maltese dog, Maltese terrier, Maltese",
|
| 156 |
+
"154": "Pekinese, Pekingese, Peke",
|
| 157 |
+
"155": "Shih-Tzu",
|
| 158 |
+
"156": "Blenheim spaniel",
|
| 159 |
+
"157": "papillon",
|
| 160 |
+
"158": "toy terrier",
|
| 161 |
+
"159": "Rhodesian ridgeback",
|
| 162 |
+
"160": "Afghan hound, Afghan",
|
| 163 |
+
"161": "basset, basset hound",
|
| 164 |
+
"162": "beagle",
|
| 165 |
+
"163": "bloodhound, sleuthhound",
|
| 166 |
+
"164": "bluetick",
|
| 167 |
+
"165": "black-and-tan coonhound",
|
| 168 |
+
"166": "Walker hound, Walker foxhound",
|
| 169 |
+
"167": "English foxhound",
|
| 170 |
+
"168": "redbone",
|
| 171 |
+
"169": "borzoi, Russian wolfhound",
|
| 172 |
+
"170": "Irish wolfhound",
|
| 173 |
+
"171": "Italian greyhound",
|
| 174 |
+
"172": "whippet",
|
| 175 |
+
"173": "Ibizan hound, Ibizan Podenco",
|
| 176 |
+
"174": "Norwegian elkhound, elkhound",
|
| 177 |
+
"175": "otterhound, otter hound",
|
| 178 |
+
"176": "Saluki, gazelle hound",
|
| 179 |
+
"177": "Scottish deerhound, deerhound",
|
| 180 |
+
"178": "Weimaraner",
|
| 181 |
+
"179": "Staffordshire bullterrier, Staffordshire bull terrier",
|
| 182 |
+
"180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
|
| 183 |
+
"181": "Bedlington terrier",
|
| 184 |
+
"182": "Border terrier",
|
| 185 |
+
"183": "Kerry blue terrier",
|
| 186 |
+
"184": "Irish terrier",
|
| 187 |
+
"185": "Norfolk terrier",
|
| 188 |
+
"186": "Norwich terrier",
|
| 189 |
+
"187": "Yorkshire terrier",
|
| 190 |
+
"188": "wire-haired fox terrier",
|
| 191 |
+
"189": "Lakeland terrier",
|
| 192 |
+
"190": "Sealyham terrier, Sealyham",
|
| 193 |
+
"191": "Airedale, Airedale terrier",
|
| 194 |
+
"192": "cairn, cairn terrier",
|
| 195 |
+
"193": "Australian terrier",
|
| 196 |
+
"194": "Dandie Dinmont, Dandie Dinmont terrier",
|
| 197 |
+
"195": "Boston bull, Boston terrier",
|
| 198 |
+
"196": "miniature schnauzer",
|
| 199 |
+
"197": "giant schnauzer",
|
| 200 |
+
"198": "standard schnauzer",
|
| 201 |
+
"199": "Scotch terrier, Scottish terrier, Scottie",
|
| 202 |
+
"200": "Tibetan terrier, chrysanthemum dog",
|
| 203 |
+
"201": "silky terrier, Sydney silky",
|
| 204 |
+
"202": "soft-coated wheaten terrier",
|
| 205 |
+
"203": "West Highland white terrier",
|
| 206 |
+
"204": "Lhasa, Lhasa apso",
|
| 207 |
+
"205": "flat-coated retriever",
|
| 208 |
+
"206": "curly-coated retriever",
|
| 209 |
+
"207": "golden retriever",
|
| 210 |
+
"208": "Labrador retriever",
|
| 211 |
+
"209": "Chesapeake Bay retriever",
|
| 212 |
+
"210": "German short-haired pointer",
|
| 213 |
+
"211": "vizsla, Hungarian pointer",
|
| 214 |
+
"212": "English setter",
|
| 215 |
+
"213": "Irish setter, red setter",
|
| 216 |
+
"214": "Gordon setter",
|
| 217 |
+
"215": "Brittany spaniel",
|
| 218 |
+
"216": "clumber, clumber spaniel",
|
| 219 |
+
"217": "English springer, English springer spaniel",
|
| 220 |
+
"218": "Welsh springer spaniel",
|
| 221 |
+
"219": "cocker spaniel, English cocker spaniel, cocker",
|
| 222 |
+
"220": "Sussex spaniel",
|
| 223 |
+
"221": "Irish water spaniel",
|
| 224 |
+
"222": "kuvasz",
|
| 225 |
+
"223": "schipperke",
|
| 226 |
+
"224": "groenendael",
|
| 227 |
+
"225": "malinois",
|
| 228 |
+
"226": "briard",
|
| 229 |
+
"227": "kelpie",
|
| 230 |
+
"228": "komondor",
|
| 231 |
+
"229": "Old English sheepdog, bobtail",
|
| 232 |
+
"230": "Shetland sheepdog, Shetland sheep dog, Shetland",
|
| 233 |
+
"231": "collie",
|
| 234 |
+
"232": "Border collie",
|
| 235 |
+
"233": "Bouvier des Flandres, Bouviers des Flandres",
|
| 236 |
+
"234": "Rottweiler",
|
| 237 |
+
"235": "German shepherd, German shepherd dog, German police dog, alsatian",
|
| 238 |
+
"236": "Doberman, Doberman pinscher",
|
| 239 |
+
"237": "miniature pinscher",
|
| 240 |
+
"238": "Greater Swiss Mountain dog",
|
| 241 |
+
"239": "Bernese mountain dog",
|
| 242 |
+
"240": "Appenzeller",
|
| 243 |
+
"241": "EntleBucher",
|
| 244 |
+
"242": "boxer",
|
| 245 |
+
"243": "bull mastiff",
|
| 246 |
+
"244": "Tibetan mastiff",
|
| 247 |
+
"245": "French bulldog",
|
| 248 |
+
"246": "Great Dane",
|
| 249 |
+
"247": "Saint Bernard, St Bernard",
|
| 250 |
+
"248": "Eskimo dog, husky",
|
| 251 |
+
"249": "malamute, malemute, Alaskan malamute",
|
| 252 |
+
"250": "Siberian husky",
|
| 253 |
+
"251": "dalmatian, coach dog, carriage dog",
|
| 254 |
+
"252": "affenpinscher, monkey pinscher, monkey dog",
|
| 255 |
+
"253": "basenji",
|
| 256 |
+
"254": "pug, pug-dog",
|
| 257 |
+
"255": "Leonberg",
|
| 258 |
+
"256": "Newfoundland, Newfoundland dog",
|
| 259 |
+
"257": "Great Pyrenees",
|
| 260 |
+
"258": "Samoyed, Samoyede",
|
| 261 |
+
"259": "Pomeranian",
|
| 262 |
+
"260": "chow, chow chow",
|
| 263 |
+
"261": "keeshond",
|
| 264 |
+
"262": "Brabancon griffon",
|
| 265 |
+
"263": "Pembroke, Pembroke Welsh corgi",
|
| 266 |
+
"264": "Cardigan, Cardigan Welsh corgi",
|
| 267 |
+
"265": "toy poodle",
|
| 268 |
+
"266": "miniature poodle",
|
| 269 |
+
"267": "standard poodle",
|
| 270 |
+
"268": "Mexican hairless",
|
| 271 |
+
"269": "timber wolf, grey wolf, gray wolf, Canis lupus",
|
| 272 |
+
"270": "white wolf, Arctic wolf, Canis lupus tundrarum",
|
| 273 |
+
"271": "red wolf, maned wolf, Canis rufus, Canis niger",
|
| 274 |
+
"272": "coyote, prairie wolf, brush wolf, Canis latrans",
|
| 275 |
+
"273": "dingo, warrigal, warragal, Canis dingo",
|
| 276 |
+
"274": "dhole, Cuon alpinus",
|
| 277 |
+
"275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
|
| 278 |
+
"276": "hyena, hyaena",
|
| 279 |
+
"277": "red fox, Vulpes vulpes",
|
| 280 |
+
"278": "kit fox, Vulpes macrotis",
|
| 281 |
+
"279": "Arctic fox, white fox, Alopex lagopus",
|
| 282 |
+
"280": "grey fox, gray fox, Urocyon cinereoargenteus",
|
| 283 |
+
"281": "tabby, tabby cat",
|
| 284 |
+
"282": "tiger cat",
|
| 285 |
+
"283": "Persian cat",
|
| 286 |
+
"284": "Siamese cat, Siamese",
|
| 287 |
+
"285": "Egyptian cat",
|
| 288 |
+
"286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
|
| 289 |
+
"287": "lynx, catamount",
|
| 290 |
+
"288": "leopard, Panthera pardus",
|
| 291 |
+
"289": "snow leopard, ounce, Panthera uncia",
|
| 292 |
+
"290": "jaguar, panther, Panthera onca, Felis onca",
|
| 293 |
+
"291": "lion, king of beasts, Panthera leo",
|
| 294 |
+
"292": "tiger, Panthera tigris",
|
| 295 |
+
"293": "cheetah, chetah, Acinonyx jubatus",
|
| 296 |
+
"294": "brown bear, bruin, Ursus arctos",
|
| 297 |
+
"295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
|
| 298 |
+
"296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
|
| 299 |
+
"297": "sloth bear, Melursus ursinus, Ursus ursinus",
|
| 300 |
+
"298": "mongoose",
|
| 301 |
+
"299": "meerkat, mierkat",
|
| 302 |
+
"300": "tiger beetle",
|
| 303 |
+
"301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
|
| 304 |
+
"302": "ground beetle, carabid beetle",
|
| 305 |
+
"303": "long-horned beetle, longicorn, longicorn beetle",
|
| 306 |
+
"304": "leaf beetle, chrysomelid",
|
| 307 |
+
"305": "dung beetle",
|
| 308 |
+
"306": "rhinoceros beetle",
|
| 309 |
+
"307": "weevil",
|
| 310 |
+
"308": "fly",
|
| 311 |
+
"309": "bee",
|
| 312 |
+
"310": "ant, emmet, pismire",
|
| 313 |
+
"311": "grasshopper, hopper",
|
| 314 |
+
"312": "cricket",
|
| 315 |
+
"313": "walking stick, walkingstick, stick insect",
|
| 316 |
+
"314": "cockroach, roach",
|
| 317 |
+
"315": "mantis, mantid",
|
| 318 |
+
"316": "cicada, cicala",
|
| 319 |
+
"317": "leafhopper",
|
| 320 |
+
"318": "lacewing, lacewing fly",
|
| 321 |
+
"319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
| 322 |
+
"320": "damselfly",
|
| 323 |
+
"321": "admiral",
|
| 324 |
+
"322": "ringlet, ringlet butterfly",
|
| 325 |
+
"323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
|
| 326 |
+
"324": "cabbage butterfly",
|
| 327 |
+
"325": "sulphur butterfly, sulfur butterfly",
|
| 328 |
+
"326": "lycaenid, lycaenid butterfly",
|
| 329 |
+
"327": "starfish, sea star",
|
| 330 |
+
"328": "sea urchin",
|
| 331 |
+
"329": "sea cucumber, holothurian",
|
| 332 |
+
"330": "wood rabbit, cottontail, cottontail rabbit",
|
| 333 |
+
"331": "hare",
|
| 334 |
+
"332": "Angora, Angora rabbit",
|
| 335 |
+
"333": "hamster",
|
| 336 |
+
"334": "porcupine, hedgehog",
|
| 337 |
+
"335": "fox squirrel, eastern fox squirrel, Sciurus niger",
|
| 338 |
+
"336": "marmot",
|
| 339 |
+
"337": "beaver",
|
| 340 |
+
"338": "guinea pig, Cavia cobaya",
|
| 341 |
+
"339": "sorrel",
|
| 342 |
+
"340": "zebra",
|
| 343 |
+
"341": "hog, pig, grunter, squealer, Sus scrofa",
|
| 344 |
+
"342": "wild boar, boar, Sus scrofa",
|
| 345 |
+
"343": "warthog",
|
| 346 |
+
"344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
|
| 347 |
+
"345": "ox",
|
| 348 |
+
"346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
|
| 349 |
+
"347": "bison",
|
| 350 |
+
"348": "ram, tup",
|
| 351 |
+
"349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
|
| 352 |
+
"350": "ibex, Capra ibex",
|
| 353 |
+
"351": "hartebeest",
|
| 354 |
+
"352": "impala, Aepyceros melampus",
|
| 355 |
+
"353": "gazelle",
|
| 356 |
+
"354": "Arabian camel, dromedary, Camelus dromedarius",
|
| 357 |
+
"355": "llama",
|
| 358 |
+
"356": "weasel",
|
| 359 |
+
"357": "mink",
|
| 360 |
+
"358": "polecat, fitch, foulmart, foumart, Mustela putorius",
|
| 361 |
+
"359": "black-footed ferret, ferret, Mustela nigripes",
|
| 362 |
+
"360": "otter",
|
| 363 |
+
"361": "skunk, polecat, wood pussy",
|
| 364 |
+
"362": "badger",
|
| 365 |
+
"363": "armadillo",
|
| 366 |
+
"364": "three-toed sloth, ai, Bradypus tridactylus",
|
| 367 |
+
"365": "orangutan, orang, orangutang, Pongo pygmaeus",
|
| 368 |
+
"366": "gorilla, Gorilla gorilla",
|
| 369 |
+
"367": "chimpanzee, chimp, Pan troglodytes",
|
| 370 |
+
"368": "gibbon, Hylobates lar",
|
| 371 |
+
"369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
|
| 372 |
+
"370": "guenon, guenon monkey",
|
| 373 |
+
"371": "patas, hussar monkey, Erythrocebus patas",
|
| 374 |
+
"372": "baboon",
|
| 375 |
+
"373": "macaque",
|
| 376 |
+
"374": "langur",
|
| 377 |
+
"375": "colobus, colobus monkey",
|
| 378 |
+
"376": "proboscis monkey, Nasalis larvatus",
|
| 379 |
+
"377": "marmoset",
|
| 380 |
+
"378": "capuchin, ringtail, Cebus capucinus",
|
| 381 |
+
"379": "howler monkey, howler",
|
| 382 |
+
"380": "titi, titi monkey",
|
| 383 |
+
"381": "spider monkey, Ateles geoffroyi",
|
| 384 |
+
"382": "squirrel monkey, Saimiri sciureus",
|
| 385 |
+
"383": "Madagascar cat, ring-tailed lemur, Lemur catta",
|
| 386 |
+
"384": "indri, indris, Indri indri, Indri brevicaudatus",
|
| 387 |
+
"385": "Indian elephant, Elephas maximus",
|
| 388 |
+
"386": "African elephant, Loxodonta africana",
|
| 389 |
+
"387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
|
| 390 |
+
"388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
|
| 391 |
+
"389": "barracouta, snoek",
|
| 392 |
+
"390": "eel",
|
| 393 |
+
"391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
|
| 394 |
+
"392": "rock beauty, Holocanthus tricolor",
|
| 395 |
+
"393": "anemone fish",
|
| 396 |
+
"394": "sturgeon",
|
| 397 |
+
"395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
|
| 398 |
+
"396": "lionfish",
|
| 399 |
+
"397": "puffer, pufferfish, blowfish, globefish",
|
| 400 |
+
"398": "abacus",
|
| 401 |
+
"399": "abaya",
|
| 402 |
+
"400": "academic gown, academic robe, judge robe",
|
| 403 |
+
"401": "accordion, piano accordion, squeeze box",
|
| 404 |
+
"402": "acoustic guitar",
|
| 405 |
+
"403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
|
| 406 |
+
"404": "airliner",
|
| 407 |
+
"405": "airship, dirigible",
|
| 408 |
+
"406": "altar",
|
| 409 |
+
"407": "ambulance",
|
| 410 |
+
"408": "amphibian, amphibious vehicle",
|
| 411 |
+
"409": "analog clock",
|
| 412 |
+
"410": "apiary, bee house",
|
| 413 |
+
"411": "apron",
|
| 414 |
+
"412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
|
| 415 |
+
"413": "assault rifle, assault gun",
|
| 416 |
+
"414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
|
| 417 |
+
"415": "bakery, bakeshop, bakehouse",
|
| 418 |
+
"416": "balance beam, beam",
|
| 419 |
+
"417": "balloon",
|
| 420 |
+
"418": "ballpoint, ballpoint pen, ballpen, Biro",
|
| 421 |
+
"419": "Band Aid",
|
| 422 |
+
"420": "banjo",
|
| 423 |
+
"421": "bannister, banister, balustrade, balusters, handrail",
|
| 424 |
+
"422": "barbell",
|
| 425 |
+
"423": "barber chair",
|
| 426 |
+
"424": "barbershop",
|
| 427 |
+
"425": "barn",
|
| 428 |
+
"426": "barometer",
|
| 429 |
+
"427": "barrel, cask",
|
| 430 |
+
"428": "barrow, garden cart, lawn cart, wheelbarrow",
|
| 431 |
+
"429": "baseball",
|
| 432 |
+
"430": "basketball",
|
| 433 |
+
"431": "bassinet",
|
| 434 |
+
"432": "bassoon",
|
| 435 |
+
"433": "bathing cap, swimming cap",
|
| 436 |
+
"434": "bath towel",
|
| 437 |
+
"435": "bathtub, bathing tub, bath, tub",
|
| 438 |
+
"436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
|
| 439 |
+
"437": "beacon, lighthouse, beacon light, pharos",
|
| 440 |
+
"438": "beaker",
|
| 441 |
+
"439": "bearskin, busby, shako",
|
| 442 |
+
"440": "beer bottle",
|
| 443 |
+
"441": "beer glass",
|
| 444 |
+
"442": "bell cote, bell cot",
|
| 445 |
+
"443": "bib",
|
| 446 |
+
"444": "bicycle-built-for-two, tandem bicycle, tandem",
|
| 447 |
+
"445": "bikini, two-piece",
|
| 448 |
+
"446": "binder, ring-binder",
|
| 449 |
+
"447": "binoculars, field glasses, opera glasses",
|
| 450 |
+
"448": "birdhouse",
|
| 451 |
+
"449": "boathouse",
|
| 452 |
+
"450": "bobsled, bobsleigh, bob",
|
| 453 |
+
"451": "bolo tie, bolo, bola tie, bola",
|
| 454 |
+
"452": "bonnet, poke bonnet",
|
| 455 |
+
"453": "bookcase",
|
| 456 |
+
"454": "bookshop, bookstore, bookstall",
|
| 457 |
+
"455": "bottlecap",
|
| 458 |
+
"456": "bow",
|
| 459 |
+
"457": "bow tie, bow-tie, bowtie",
|
| 460 |
+
"458": "brass, memorial tablet, plaque",
|
| 461 |
+
"459": "brassiere, bra, bandeau",
|
| 462 |
+
"460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
|
| 463 |
+
"461": "breastplate, aegis, egis",
|
| 464 |
+
"462": "broom",
|
| 465 |
+
"463": "bucket, pail",
|
| 466 |
+
"464": "buckle",
|
| 467 |
+
"465": "bulletproof vest",
|
| 468 |
+
"466": "bullet train, bullet",
|
| 469 |
+
"467": "butcher shop, meat market",
|
| 470 |
+
"468": "cab, hack, taxi, taxicab",
|
| 471 |
+
"469": "caldron, cauldron",
|
| 472 |
+
"470": "candle, taper, wax light",
|
| 473 |
+
"471": "cannon",
|
| 474 |
+
"472": "canoe",
|
| 475 |
+
"473": "can opener, tin opener",
|
| 476 |
+
"474": "cardigan",
|
| 477 |
+
"475": "car mirror",
|
| 478 |
+
"476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
|
| 479 |
+
"477": "carpenters kit, tool kit",
|
| 480 |
+
"478": "carton",
|
| 481 |
+
"479": "car wheel",
|
| 482 |
+
"480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
|
| 483 |
+
"481": "cassette",
|
| 484 |
+
"482": "cassette player",
|
| 485 |
+
"483": "castle",
|
| 486 |
+
"484": "catamaran",
|
| 487 |
+
"485": "CD player",
|
| 488 |
+
"486": "cello, violoncello",
|
| 489 |
+
"487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
|
| 490 |
+
"488": "chain",
|
| 491 |
+
"489": "chainlink fence",
|
| 492 |
+
"490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
|
| 493 |
+
"491": "chain saw, chainsaw",
|
| 494 |
+
"492": "chest",
|
| 495 |
+
"493": "chiffonier, commode",
|
| 496 |
+
"494": "chime, bell, gong",
|
| 497 |
+
"495": "china cabinet, china closet",
|
| 498 |
+
"496": "Christmas stocking",
|
| 499 |
+
"497": "church, church building",
|
| 500 |
+
"498": "cinema, movie theater, movie theatre, movie house, picture palace",
|
| 501 |
+
"499": "cleaver, meat cleaver, chopper",
|
| 502 |
+
"500": "cliff dwelling",
|
| 503 |
+
"501": "cloak",
|
| 504 |
+
"502": "clog, geta, patten, sabot",
|
| 505 |
+
"503": "cocktail shaker",
|
| 506 |
+
"504": "coffee mug",
|
| 507 |
+
"505": "coffeepot",
|
| 508 |
+
"506": "coil, spiral, volute, whorl, helix",
|
| 509 |
+
"507": "combination lock",
|
| 510 |
+
"508": "computer keyboard, keypad",
|
| 511 |
+
"509": "confectionery, confectionary, candy store",
|
| 512 |
+
"510": "container ship, containership, container vessel",
|
| 513 |
+
"511": "convertible",
|
| 514 |
+
"512": "corkscrew, bottle screw",
|
| 515 |
+
"513": "cornet, horn, trumpet, trump",
|
| 516 |
+
"514": "cowboy boot",
|
| 517 |
+
"515": "cowboy hat, ten-gallon hat",
|
| 518 |
+
"516": "cradle",
|
| 519 |
+
"517": "crane",
|
| 520 |
+
"518": "crash helmet",
|
| 521 |
+
"519": "crate",
|
| 522 |
+
"520": "crib, cot",
|
| 523 |
+
"521": "Crock Pot",
|
| 524 |
+
"522": "croquet ball",
|
| 525 |
+
"523": "crutch",
|
| 526 |
+
"524": "cuirass",
|
| 527 |
+
"525": "dam, dike, dyke",
|
| 528 |
+
"526": "desk",
|
| 529 |
+
"527": "desktop computer",
|
| 530 |
+
"528": "dial telephone, dial phone",
|
| 531 |
+
"529": "diaper, nappy, napkin",
|
| 532 |
+
"530": "digital clock",
|
| 533 |
+
"531": "digital watch",
|
| 534 |
+
"532": "dining table, board",
|
| 535 |
+
"533": "dishrag, dishcloth",
|
| 536 |
+
"534": "dishwasher, dish washer, dishwashing machine",
|
| 537 |
+
"535": "disk brake, disc brake",
|
| 538 |
+
"536": "dock, dockage, docking facility",
|
| 539 |
+
"537": "dogsled, dog sled, dog sleigh",
|
| 540 |
+
"538": "dome",
|
| 541 |
+
"539": "doormat, welcome mat",
|
| 542 |
+
"540": "drilling platform, offshore rig",
|
| 543 |
+
"541": "drum, membranophone, tympan",
|
| 544 |
+
"542": "drumstick",
|
| 545 |
+
"543": "dumbbell",
|
| 546 |
+
"544": "Dutch oven",
|
| 547 |
+
"545": "electric fan, blower",
|
| 548 |
+
"546": "electric guitar",
|
| 549 |
+
"547": "electric locomotive",
|
| 550 |
+
"548": "entertainment center",
|
| 551 |
+
"549": "envelope",
|
| 552 |
+
"550": "espresso maker",
|
| 553 |
+
"551": "face powder",
|
| 554 |
+
"552": "feather boa, boa",
|
| 555 |
+
"553": "file, file cabinet, filing cabinet",
|
| 556 |
+
"554": "fireboat",
|
| 557 |
+
"555": "fire engine, fire truck",
|
| 558 |
+
"556": "fire screen, fireguard",
|
| 559 |
+
"557": "flagpole, flagstaff",
|
| 560 |
+
"558": "flute, transverse flute",
|
| 561 |
+
"559": "folding chair",
|
| 562 |
+
"560": "football helmet",
|
| 563 |
+
"561": "forklift",
|
| 564 |
+
"562": "fountain",
|
| 565 |
+
"563": "fountain pen",
|
| 566 |
+
"564": "four-poster",
|
| 567 |
+
"565": "freight car",
|
| 568 |
+
"566": "French horn, horn",
|
| 569 |
+
"567": "frying pan, frypan, skillet",
|
| 570 |
+
"568": "fur coat",
|
| 571 |
+
"569": "garbage truck, dustcart",
|
| 572 |
+
"570": "gasmask, respirator, gas helmet",
|
| 573 |
+
"571": "gas pump, gasoline pump, petrol pump, island dispenser",
|
| 574 |
+
"572": "goblet",
|
| 575 |
+
"573": "go-kart",
|
| 576 |
+
"574": "golf ball",
|
| 577 |
+
"575": "golfcart, golf cart",
|
| 578 |
+
"576": "gondola",
|
| 579 |
+
"577": "gong, tam-tam",
|
| 580 |
+
"578": "gown",
|
| 581 |
+
"579": "grand piano, grand",
|
| 582 |
+
"580": "greenhouse, nursery, glasshouse",
|
| 583 |
+
"581": "grille, radiator grille",
|
| 584 |
+
"582": "grocery store, grocery, food market, market",
|
| 585 |
+
"583": "guillotine",
|
| 586 |
+
"584": "hair slide",
|
| 587 |
+
"585": "hair spray",
|
| 588 |
+
"586": "half track",
|
| 589 |
+
"587": "hammer",
|
| 590 |
+
"588": "hamper",
|
| 591 |
+
"589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
|
| 592 |
+
"590": "hand-held computer, hand-held microcomputer",
|
| 593 |
+
"591": "handkerchief, hankie, hanky, hankey",
|
| 594 |
+
"592": "hard disc, hard disk, fixed disk",
|
| 595 |
+
"593": "harmonica, mouth organ, harp, mouth harp",
|
| 596 |
+
"594": "harp",
|
| 597 |
+
"595": "harvester, reaper",
|
| 598 |
+
"596": "hatchet",
|
| 599 |
+
"597": "holster",
|
| 600 |
+
"598": "home theater, home theatre",
|
| 601 |
+
"599": "honeycomb",
|
| 602 |
+
"600": "hook, claw",
|
| 603 |
+
"601": "hoopskirt, crinoline",
|
| 604 |
+
"602": "horizontal bar, high bar",
|
| 605 |
+
"603": "horse cart, horse-cart",
|
| 606 |
+
"604": "hourglass",
|
| 607 |
+
"605": "iPod",
|
| 608 |
+
"606": "iron, smoothing iron",
|
| 609 |
+
"607": "jack-o-lantern",
|
| 610 |
+
"608": "jean, blue jean, denim",
|
| 611 |
+
"609": "jeep, landrover",
|
| 612 |
+
"610": "jersey, T-shirt, tee shirt",
|
| 613 |
+
"611": "jigsaw puzzle",
|
| 614 |
+
"612": "jinrikisha, ricksha, rickshaw",
|
| 615 |
+
"613": "joystick",
|
| 616 |
+
"614": "kimono",
|
| 617 |
+
"615": "knee pad",
|
| 618 |
+
"616": "knot",
|
| 619 |
+
"617": "lab coat, laboratory coat",
|
| 620 |
+
"618": "ladle",
|
| 621 |
+
"619": "lampshade, lamp shade",
|
| 622 |
+
"620": "laptop, laptop computer",
|
| 623 |
+
"621": "lawn mower, mower",
|
| 624 |
+
"622": "lens cap, lens cover",
|
| 625 |
+
"623": "letter opener, paper knife, paperknife",
|
| 626 |
+
"624": "library",
|
| 627 |
+
"625": "lifeboat",
|
| 628 |
+
"626": "lighter, light, igniter, ignitor",
|
| 629 |
+
"627": "limousine, limo",
|
| 630 |
+
"628": "liner, ocean liner",
|
| 631 |
+
"629": "lipstick, lip rouge",
|
| 632 |
+
"630": "Loafer",
|
| 633 |
+
"631": "lotion",
|
| 634 |
+
"632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
|
| 635 |
+
"633": "loupe, jewelers loupe",
|
| 636 |
+
"634": "lumbermill, sawmill",
|
| 637 |
+
"635": "magnetic compass",
|
| 638 |
+
"636": "mailbag, postbag",
|
| 639 |
+
"637": "mailbox, letter box",
|
| 640 |
+
"638": "maillot",
|
| 641 |
+
"639": "maillot, tank suit",
|
| 642 |
+
"640": "manhole cover",
|
| 643 |
+
"641": "maraca",
|
| 644 |
+
"642": "marimba, xylophone",
|
| 645 |
+
"643": "mask",
|
| 646 |
+
"644": "matchstick",
|
| 647 |
+
"645": "maypole",
|
| 648 |
+
"646": "maze, labyrinth",
|
| 649 |
+
"647": "measuring cup",
|
| 650 |
+
"648": "medicine chest, medicine cabinet",
|
| 651 |
+
"649": "megalith, megalithic structure",
|
| 652 |
+
"650": "microphone, mike",
|
| 653 |
+
"651": "microwave, microwave oven",
|
| 654 |
+
"652": "military uniform",
|
| 655 |
+
"653": "milk can",
|
| 656 |
+
"654": "minibus",
|
| 657 |
+
"655": "miniskirt, mini",
|
| 658 |
+
"656": "minivan",
|
| 659 |
+
"657": "missile",
|
| 660 |
+
"658": "mitten",
|
| 661 |
+
"659": "mixing bowl",
|
| 662 |
+
"660": "mobile home, manufactured home",
|
| 663 |
+
"661": "Model T",
|
| 664 |
+
"662": "modem",
|
| 665 |
+
"663": "monastery",
|
| 666 |
+
"664": "monitor",
|
| 667 |
+
"665": "moped",
|
| 668 |
+
"666": "mortar",
|
| 669 |
+
"667": "mortarboard",
|
| 670 |
+
"668": "mosque",
|
| 671 |
+
"669": "mosquito net",
|
| 672 |
+
"670": "motor scooter, scooter",
|
| 673 |
+
"671": "mountain bike, all-terrain bike, off-roader",
|
| 674 |
+
"672": "mountain tent",
|
| 675 |
+
"673": "mouse, computer mouse",
|
| 676 |
+
"674": "mousetrap",
|
| 677 |
+
"675": "moving van",
|
| 678 |
+
"676": "muzzle",
|
| 679 |
+
"677": "nail",
|
| 680 |
+
"678": "neck brace",
|
| 681 |
+
"679": "necklace",
|
| 682 |
+
"680": "nipple",
|
| 683 |
+
"681": "notebook, notebook computer",
|
| 684 |
+
"682": "obelisk",
|
| 685 |
+
"683": "oboe, hautboy, hautbois",
|
| 686 |
+
"684": "ocarina, sweet potato",
|
| 687 |
+
"685": "odometer, hodometer, mileometer, milometer",
|
| 688 |
+
"686": "oil filter",
|
| 689 |
+
"687": "organ, pipe organ",
|
| 690 |
+
"688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
|
| 691 |
+
"689": "overskirt",
|
| 692 |
+
"690": "oxcart",
|
| 693 |
+
"691": "oxygen mask",
|
| 694 |
+
"692": "packet",
|
| 695 |
+
"693": "paddle, boat paddle",
|
| 696 |
+
"694": "paddlewheel, paddle wheel",
|
| 697 |
+
"695": "padlock",
|
| 698 |
+
"696": "paintbrush",
|
| 699 |
+
"697": "pajama, pyjama, pjs, jammies",
|
| 700 |
+
"698": "palace",
|
| 701 |
+
"699": "panpipe, pandean pipe, syrinx",
|
| 702 |
+
"700": "paper towel",
|
| 703 |
+
"701": "parachute, chute",
|
| 704 |
+
"702": "parallel bars, bars",
|
| 705 |
+
"703": "park bench",
|
| 706 |
+
"704": "parking meter",
|
| 707 |
+
"705": "passenger car, coach, carriage",
|
| 708 |
+
"706": "patio, terrace",
|
| 709 |
+
"707": "pay-phone, pay-station",
|
| 710 |
+
"708": "pedestal, plinth, footstall",
|
| 711 |
+
"709": "pencil box, pencil case",
|
| 712 |
+
"710": "pencil sharpener",
|
| 713 |
+
"711": "perfume, essence",
|
| 714 |
+
"712": "Petri dish",
|
| 715 |
+
"713": "photocopier",
|
| 716 |
+
"714": "pick, plectrum, plectron",
|
| 717 |
+
"715": "pickelhaube",
|
| 718 |
+
"716": "picket fence, paling",
|
| 719 |
+
"717": "pickup, pickup truck",
|
| 720 |
+
"718": "pier",
|
| 721 |
+
"719": "piggy bank, penny bank",
|
| 722 |
+
"720": "pill bottle",
|
| 723 |
+
"721": "pillow",
|
| 724 |
+
"722": "ping-pong ball",
|
| 725 |
+
"723": "pinwheel",
|
| 726 |
+
"724": "pirate, pirate ship",
|
| 727 |
+
"725": "pitcher, ewer",
|
| 728 |
+
"726": "plane, carpenters plane, woodworking plane",
|
| 729 |
+
"727": "planetarium",
|
| 730 |
+
"728": "plastic bag",
|
| 731 |
+
"729": "plate rack",
|
| 732 |
+
"730": "plow, plough",
|
| 733 |
+
"731": "plunger, plumbers helper",
|
| 734 |
+
"732": "Polaroid camera, Polaroid Land camera",
|
| 735 |
+
"733": "pole",
|
| 736 |
+
"734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
|
| 737 |
+
"735": "poncho",
|
| 738 |
+
"736": "pool table, billiard table, snooker table",
|
| 739 |
+
"737": "pop bottle, soda bottle",
|
| 740 |
+
"738": "pot, flowerpot",
|
| 741 |
+
"739": "potters wheel",
|
| 742 |
+
"740": "power drill",
|
| 743 |
+
"741": "prayer rug, prayer mat",
|
| 744 |
+
"742": "printer",
|
| 745 |
+
"743": "prison, prison house",
|
| 746 |
+
"744": "projectile, missile",
|
| 747 |
+
"745": "projector",
|
| 748 |
+
"746": "puck, hockey puck",
|
| 749 |
+
"747": "punching bag, punch bag, punching ball, punchball",
|
| 750 |
+
"748": "purse",
|
| 751 |
+
"749": "quill, quill pen",
|
| 752 |
+
"750": "quilt, comforter, comfort, puff",
|
| 753 |
+
"751": "racer, race car, racing car",
|
| 754 |
+
"752": "racket, racquet",
|
| 755 |
+
"753": "radiator",
|
| 756 |
+
"754": "radio, wireless",
|
| 757 |
+
"755": "radio telescope, radio reflector",
|
| 758 |
+
"756": "rain barrel",
|
| 759 |
+
"757": "recreational vehicle, RV, R.V.",
|
| 760 |
+
"758": "reel",
|
| 761 |
+
"759": "reflex camera",
|
| 762 |
+
"760": "refrigerator, icebox",
|
| 763 |
+
"761": "remote control, remote",
|
| 764 |
+
"762": "restaurant, eating house, eating place, eatery",
|
| 765 |
+
"763": "revolver, six-gun, six-shooter",
|
| 766 |
+
"764": "rifle",
|
| 767 |
+
"765": "rocking chair, rocker",
|
| 768 |
+
"766": "rotisserie",
|
| 769 |
+
"767": "rubber eraser, rubber, pencil eraser",
|
| 770 |
+
"768": "rugby ball",
|
| 771 |
+
"769": "rule, ruler",
|
| 772 |
+
"770": "running shoe",
|
| 773 |
+
"771": "safe",
|
| 774 |
+
"772": "safety pin",
|
| 775 |
+
"773": "saltshaker, salt shaker",
|
| 776 |
+
"774": "sandal",
|
| 777 |
+
"775": "sarong",
|
| 778 |
+
"776": "sax, saxophone",
|
| 779 |
+
"777": "scabbard",
|
| 780 |
+
"778": "scale, weighing machine",
|
| 781 |
+
"779": "school bus",
|
| 782 |
+
"780": "schooner",
|
| 783 |
+
"781": "scoreboard",
|
| 784 |
+
"782": "screen, CRT screen",
|
| 785 |
+
"783": "screw",
|
| 786 |
+
"784": "screwdriver",
|
| 787 |
+
"785": "seat belt, seatbelt",
|
| 788 |
+
"786": "sewing machine",
|
| 789 |
+
"787": "shield, buckler",
|
| 790 |
+
"788": "shoe shop, shoe-shop, shoe store",
|
| 791 |
+
"789": "shoji",
|
| 792 |
+
"790": "shopping basket",
|
| 793 |
+
"791": "shopping cart",
|
| 794 |
+
"792": "shovel",
|
| 795 |
+
"793": "shower cap",
|
| 796 |
+
"794": "shower curtain",
|
| 797 |
+
"795": "ski",
|
| 798 |
+
"796": "ski mask",
|
| 799 |
+
"797": "sleeping bag",
|
| 800 |
+
"798": "slide rule, slipstick",
|
| 801 |
+
"799": "sliding door",
|
| 802 |
+
"800": "slot, one-armed bandit",
|
| 803 |
+
"801": "snorkel",
|
| 804 |
+
"802": "snowmobile",
|
| 805 |
+
"803": "snowplow, snowplough",
|
| 806 |
+
"804": "soap dispenser",
|
| 807 |
+
"805": "soccer ball",
|
| 808 |
+
"806": "sock",
|
| 809 |
+
"807": "solar dish, solar collector, solar furnace",
|
| 810 |
+
"808": "sombrero",
|
| 811 |
+
"809": "soup bowl",
|
| 812 |
+
"810": "space bar",
|
| 813 |
+
"811": "space heater",
|
| 814 |
+
"812": "space shuttle",
|
| 815 |
+
"813": "spatula",
|
| 816 |
+
"814": "speedboat",
|
| 817 |
+
"815": "spider web, spiders web",
|
| 818 |
+
"816": "spindle",
|
| 819 |
+
"817": "sports car, sport car",
|
| 820 |
+
"818": "spotlight, spot",
|
| 821 |
+
"819": "stage",
|
| 822 |
+
"820": "steam locomotive",
|
| 823 |
+
"821": "steel arch bridge",
|
| 824 |
+
"822": "steel drum",
|
| 825 |
+
"823": "stethoscope",
|
| 826 |
+
"824": "stole",
|
| 827 |
+
"825": "stone wall",
|
| 828 |
+
"826": "stopwatch, stop watch",
|
| 829 |
+
"827": "stove",
|
| 830 |
+
"828": "strainer",
|
| 831 |
+
"829": "streetcar, tram, tramcar, trolley, trolley car",
|
| 832 |
+
"830": "stretcher",
|
| 833 |
+
"831": "studio couch, day bed",
|
| 834 |
+
"832": "stupa, tope",
|
| 835 |
+
"833": "submarine, pigboat, sub, U-boat",
|
| 836 |
+
"834": "suit, suit of clothes",
|
| 837 |
+
"835": "sundial",
|
| 838 |
+
"836": "sunglass",
|
| 839 |
+
"837": "sunglasses, dark glasses, shades",
|
| 840 |
+
"838": "sunscreen, sunblock, sun blocker",
|
| 841 |
+
"839": "suspension bridge",
|
| 842 |
+
"840": "swab, swob, mop",
|
| 843 |
+
"841": "sweatshirt",
|
| 844 |
+
"842": "swimming trunks, bathing trunks",
|
| 845 |
+
"843": "swing",
|
| 846 |
+
"844": "switch, electric switch, electrical switch",
|
| 847 |
+
"845": "syringe",
|
| 848 |
+
"846": "table lamp",
|
| 849 |
+
"847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
|
| 850 |
+
"848": "tape player",
|
| 851 |
+
"849": "teapot",
|
| 852 |
+
"850": "teddy, teddy bear",
|
| 853 |
+
"851": "television, television system",
|
| 854 |
+
"852": "tennis ball",
|
| 855 |
+
"853": "thatch, thatched roof",
|
| 856 |
+
"854": "theater curtain, theatre curtain",
|
| 857 |
+
"855": "thimble",
|
| 858 |
+
"856": "thresher, thrasher, threshing machine",
|
| 859 |
+
"857": "throne",
|
| 860 |
+
"858": "tile roof",
|
| 861 |
+
"859": "toaster",
|
| 862 |
+
"860": "tobacco shop, tobacconist shop, tobacconist",
|
| 863 |
+
"861": "toilet seat",
|
| 864 |
+
"862": "torch",
|
| 865 |
+
"863": "totem pole",
|
| 866 |
+
"864": "tow truck, tow car, wrecker",
|
| 867 |
+
"865": "toyshop",
|
| 868 |
+
"866": "tractor",
|
| 869 |
+
"867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
|
| 870 |
+
"868": "tray",
|
| 871 |
+
"869": "trench coat",
|
| 872 |
+
"870": "tricycle, trike, velocipede",
|
| 873 |
+
"871": "trimaran",
|
| 874 |
+
"872": "tripod",
|
| 875 |
+
"873": "triumphal arch",
|
| 876 |
+
"874": "trolleybus, trolley coach, trackless trolley",
|
| 877 |
+
"875": "trombone",
|
| 878 |
+
"876": "tub, vat",
|
| 879 |
+
"877": "turnstile",
|
| 880 |
+
"878": "typewriter keyboard",
|
| 881 |
+
"879": "umbrella",
|
| 882 |
+
"880": "unicycle, monocycle",
|
| 883 |
+
"881": "upright, upright piano",
|
| 884 |
+
"882": "vacuum, vacuum cleaner",
|
| 885 |
+
"883": "vase",
|
| 886 |
+
"884": "vault",
|
| 887 |
+
"885": "velvet",
|
| 888 |
+
"886": "vending machine",
|
| 889 |
+
"887": "vestment",
|
| 890 |
+
"888": "viaduct",
|
| 891 |
+
"889": "violin, fiddle",
|
| 892 |
+
"890": "volleyball",
|
| 893 |
+
"891": "waffle iron",
|
| 894 |
+
"892": "wall clock",
|
| 895 |
+
"893": "wallet, billfold, notecase, pocketbook",
|
| 896 |
+
"894": "wardrobe, closet, press",
|
| 897 |
+
"895": "warplane, military plane",
|
| 898 |
+
"896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
|
| 899 |
+
"897": "washer, automatic washer, washing machine",
|
| 900 |
+
"898": "water bottle",
|
| 901 |
+
"899": "water jug",
|
| 902 |
+
"900": "water tower",
|
| 903 |
+
"901": "whiskey jug",
|
| 904 |
+
"902": "whistle",
|
| 905 |
+
"903": "wig",
|
| 906 |
+
"904": "window screen",
|
| 907 |
+
"905": "window shade",
|
| 908 |
+
"906": "Windsor tie",
|
| 909 |
+
"907": "wine bottle",
|
| 910 |
+
"908": "wing",
|
| 911 |
+
"909": "wok",
|
| 912 |
+
"910": "wooden spoon",
|
| 913 |
+
"911": "wool, woolen, woollen",
|
| 914 |
+
"912": "worm fence, snake fence, snake-rail fence, Virginia fence",
|
| 915 |
+
"913": "wreck",
|
| 916 |
+
"914": "yawl",
|
| 917 |
+
"915": "yurt",
|
| 918 |
+
"916": "web site, website, internet site, site",
|
| 919 |
+
"917": "comic book",
|
| 920 |
+
"918": "crossword puzzle, crossword",
|
| 921 |
+
"919": "street sign",
|
| 922 |
+
"920": "traffic light, traffic signal, stoplight",
|
| 923 |
+
"921": "book jacket, dust cover, dust jacket, dust wrapper",
|
| 924 |
+
"922": "menu",
|
| 925 |
+
"923": "plate",
|
| 926 |
+
"924": "guacamole",
|
| 927 |
+
"925": "consomme",
|
| 928 |
+
"926": "hot pot, hotpot",
|
| 929 |
+
"927": "trifle",
|
| 930 |
+
"928": "ice cream, icecream",
|
| 931 |
+
"929": "ice lolly, lolly, lollipop, popsicle",
|
| 932 |
+
"930": "French loaf",
|
| 933 |
+
"931": "bagel, beigel",
|
| 934 |
+
"932": "pretzel",
|
| 935 |
+
"933": "cheeseburger",
|
| 936 |
+
"934": "hotdog, hot dog, red hot",
|
| 937 |
+
"935": "mashed potato",
|
| 938 |
+
"936": "head cabbage",
|
| 939 |
+
"937": "broccoli",
|
| 940 |
+
"938": "cauliflower",
|
| 941 |
+
"939": "zucchini, courgette",
|
| 942 |
+
"940": "spaghetti squash",
|
| 943 |
+
"941": "acorn squash",
|
| 944 |
+
"942": "butternut squash",
|
| 945 |
+
"943": "cucumber, cuke",
|
| 946 |
+
"944": "artichoke, globe artichoke",
|
| 947 |
+
"945": "bell pepper",
|
| 948 |
+
"946": "cardoon",
|
| 949 |
+
"947": "mushroom",
|
| 950 |
+
"948": "Granny Smith",
|
| 951 |
+
"949": "strawberry",
|
| 952 |
+
"950": "orange",
|
| 953 |
+
"951": "lemon",
|
| 954 |
+
"952": "fig",
|
| 955 |
+
"953": "pineapple, ananas",
|
| 956 |
+
"954": "banana",
|
| 957 |
+
"955": "jackfruit, jak, jack",
|
| 958 |
+
"956": "custard apple",
|
| 959 |
+
"957": "pomegranate",
|
| 960 |
+
"958": "hay",
|
| 961 |
+
"959": "carbonara",
|
| 962 |
+
"960": "chocolate sauce, chocolate syrup",
|
| 963 |
+
"961": "dough",
|
| 964 |
+
"962": "meat loaf, meatloaf",
|
| 965 |
+
"963": "pizza, pizza pie",
|
| 966 |
+
"964": "potpie",
|
| 967 |
+
"965": "burrito",
|
| 968 |
+
"966": "red wine",
|
| 969 |
+
"967": "espresso",
|
| 970 |
+
"968": "cup",
|
| 971 |
+
"969": "eggnog",
|
| 972 |
+
"970": "alp",
|
| 973 |
+
"971": "bubble",
|
| 974 |
+
"972": "cliff, drop, drop-off",
|
| 975 |
+
"973": "coral reef",
|
| 976 |
+
"974": "geyser",
|
| 977 |
+
"975": "lakeside, lakeshore",
|
| 978 |
+
"976": "promontory, headland, head, foreland",
|
| 979 |
+
"977": "sandbar, sand bar",
|
| 980 |
+
"978": "seashore, coast, seacoast, sea-coast",
|
| 981 |
+
"979": "valley, vale",
|
| 982 |
+
"980": "volcano",
|
| 983 |
+
"981": "ballplayer, baseball player",
|
| 984 |
+
"982": "groom, bridegroom",
|
| 985 |
+
"983": "scuba diver",
|
| 986 |
+
"984": "rapeseed",
|
| 987 |
+
"985": "daisy",
|
| 988 |
+
"986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
| 989 |
+
"987": "corn",
|
| 990 |
+
"988": "acorn",
|
| 991 |
+
"989": "hip, rose hip, rosehip",
|
| 992 |
+
"990": "buckeye, horse chestnut, conker",
|
| 993 |
+
"991": "coral fungus",
|
| 994 |
+
"992": "agaric",
|
| 995 |
+
"993": "gyromitra",
|
| 996 |
+
"994": "stinkhorn, carrion fungus",
|
| 997 |
+
"995": "earthstar",
|
| 998 |
+
"996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
|
| 999 |
+
"997": "bolete",
|
| 1000 |
+
"998": "ear, spike, capitulum",
|
| 1001 |
+
"999": "toilet tissue, toilet paper, bathroom tissue"
|
| 1002 |
+
}
|