diff --git a/.gitattributes b/.gitattributes index 53e5157ce0845c5ded8ccd278aac212ceaf73170..46930785edca92b88b7dc5cbcd99acf5958b4cd0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text demo.png filter=lfs diff=lfs merge=lfs -text demo_images/jit_h32_test_inference.png filter=lfs diff=lfs merge=lfs -text +demo_images/jit_h32_final_test.png filter=lfs diff=lfs merge=lfs -text diff --git a/JiT-B-16/model_index.json b/JiT-B-16/model_index.json index 20edd748a56adb768b31321efe2e2a1855c71ab3..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-B-16/model_index.json +++ b/JiT-B-16/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] } diff --git a/JiT-B-16/pipeline.py b/JiT-B-16/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-B-16/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-B-16/scheduler/scheduler_config.json b/JiT-B-16/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-B-16/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-B-16/scheduler/scheduling_jit.py b/JiT-B-16/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-B-16/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-B-16/transformer/config.json b/JiT-B-16/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..f016269fa1e363f9b169209b73e3a202d91dc017 --- /dev/null +++ b/JiT-B-16/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 128, + "dropout": 0.0, + "hidden_size": 768, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 4, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 12, + "num_classes": 1000, + "num_layers": 12, + "patch_size": 16, + "sample_size": 256 +} diff --git a/JiT-B-16/transformer/diffusion_pytorch_model.safetensors b/JiT-B-16/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0170e303ba58a279934b8577ac29c5f050dfc249 --- /dev/null +++ b/JiT-B-16/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b68278f2e16a2842bbc17e7d38bc08d22475e1d748bb2e672a9b7e8aff5b4772 +size 525298808 diff --git a/JiT-B-16/transformer/jit_transformer_2d.py b/JiT-B-16/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-B-16/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/JiT-B-32/model_index.json b/JiT-B-32/model_index.json index 20edd748a56adb768b31321efe2e2a1855c71ab3..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-B-32/model_index.json +++ b/JiT-B-32/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] } diff --git a/JiT-B-32/pipeline.py b/JiT-B-32/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-B-32/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-B-32/scheduler/scheduler_config.json b/JiT-B-32/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-B-32/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-B-32/scheduler/scheduling_jit.py b/JiT-B-32/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-B-32/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-B-32/transformer/config.json b/JiT-B-32/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..412ffcf8b2aa709a6fd427a1944ab6088b6c8e7d --- /dev/null +++ b/JiT-B-32/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 128, + "dropout": 0.0, + "hidden_size": 768, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 4, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 12, + "num_classes": 1000, + "num_layers": 12, + "patch_size": 32, + "sample_size": 512 +} diff --git a/JiT-B-32/transformer/diffusion_pytorch_model.safetensors b/JiT-B-32/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..648c81efb38c6cabd5c9a2080d802d0fcd8880b0 --- /dev/null +++ b/JiT-B-32/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:729654b3302fdae22eb4a4de9d2b24545828c82f2e2c8dcd3f5a01fe7c606ba4 +size 533565560 diff --git a/JiT-B-32/transformer/jit_transformer_2d.py b/JiT-B-32/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-B-32/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/JiT-H-16/model_index.json b/JiT-H-16/model_index.json index 20edd748a56adb768b31321efe2e2a1855c71ab3..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-H-16/model_index.json +++ b/JiT-H-16/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] } diff --git a/JiT-H-16/pipeline.py b/JiT-H-16/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-H-16/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-H-16/scheduler/scheduler_config.json b/JiT-H-16/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-H-16/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-H-16/scheduler/scheduling_jit.py b/JiT-H-16/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-H-16/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-H-16/transformer/config.json b/JiT-H-16/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ae4bf474c561e132cd3b5d3336c755fa7ae5bda6 --- /dev/null +++ b/JiT-H-16/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 256, + "dropout": 0.2, + "hidden_size": 1280, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 10, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 16, + "num_classes": 1000, + "num_layers": 32, + "patch_size": 16, + "sample_size": 256 +} diff --git a/JiT-H-16/transformer/diffusion_pytorch_model.safetensors b/JiT-H-16/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..4bbc040e55eeaeda48d0831bd3d8c7ddf204c2af --- /dev/null +++ b/JiT-H-16/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6ad4cf51f5ff385db58573a23353b50df4be7a63dd50bdc7b57af404e7b68e7 +size 3811413928 diff --git a/JiT-H-16/transformer/jit_transformer_2d.py b/JiT-H-16/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-H-16/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/JiT-H-32/model_index.json b/JiT-H-32/model_index.json index 3278c52c2b1f4dc3fc32c685d826632d4a9dd9df..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-H-32/model_index.json +++ b/JiT-H-32/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] -} \ No newline at end of file +} diff --git a/JiT-H-32/pipeline.py b/JiT-H-32/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-H-32/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-H-32/scheduler/scheduler_config.json b/JiT-H-32/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-H-32/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-H-32/scheduler/scheduling_jit.py b/JiT-H-32/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-H-32/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-H-32/transformer/config.json b/JiT-H-32/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..ce1cada392addd3f5b947fbff53347f616a2b628 --- /dev/null +++ b/JiT-H-32/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 256, + "dropout": 0.2, + "hidden_size": 1280, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 10, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 16, + "num_classes": 1000, + "num_layers": 32, + "patch_size": 32, + "sample_size": 512 +} diff --git a/JiT-H-32/transformer/diffusion_pytorch_model.safetensors b/JiT-H-32/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..482e445ea882a8783e8eeaae50aa45ddc38b70d5 --- /dev/null +++ b/JiT-H-32/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:578fc2f9f4ccaa34c3d2f5076811e101419e5dfd1b20dcca89bbfb29f5f60ab6 +size 3825578920 diff --git a/JiT-H-32/transformer/jit_transformer_2d.py b/JiT-H-32/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-H-32/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/JiT-L-16/model_index.json b/JiT-L-16/model_index.json index 20edd748a56adb768b31321efe2e2a1855c71ab3..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-L-16/model_index.json +++ b/JiT-L-16/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] } diff --git a/JiT-L-16/pipeline.py b/JiT-L-16/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-L-16/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-L-16/scheduler/scheduler_config.json b/JiT-L-16/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-L-16/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-L-16/scheduler/scheduling_jit.py b/JiT-L-16/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-L-16/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-L-16/transformer/config.json b/JiT-L-16/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..9e0d13b6090911eeafaa0378a92f8f8615740369 --- /dev/null +++ b/JiT-L-16/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 128, + "dropout": 0.0, + "hidden_size": 1024, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 8, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 16, + "num_classes": 1000, + "num_layers": 24, + "patch_size": 16, + "sample_size": 256 +} diff --git a/JiT-L-16/transformer/diffusion_pytorch_model.safetensors b/JiT-L-16/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..e8c9ff2fe1d3ca16f9cb9fb54a77d9b1c0f3df58 --- /dev/null +++ b/JiT-L-16/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9285393d92db078237e8adc552d6c9314c898c710ca1dfb4d3503fda0016b0f +size 1836593656 diff --git a/JiT-L-16/transformer/jit_transformer_2d.py b/JiT-L-16/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-L-16/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/JiT-L-32/model_index.json b/JiT-L-32/model_index.json index 20edd748a56adb768b31321efe2e2a1855c71ab3..fa18cbcc32203c64fd174626ff563c5f533fb945 100644 --- a/JiT-L-32/model_index.json +++ b/JiT-L-32/model_index.json @@ -1,8 +1,15 @@ { - "_class_name": "JiTPipeline", + "_class_name": [ + "pipeline", + "JiTPipeline" + ], "_diffusers_version": "0.36.0", + "scheduler": [ + "scheduling_jit", + "JiTScheduler" + ], "transformer": [ - "jit_diffusers", + "jit_transformer_2d", "JiTTransformer2DModel" ] } diff --git a/JiT-L-32/pipeline.py b/JiT-L-32/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6196a7db5f800a02de65d7b100cf3474cc67dcf7 --- /dev/null +++ b/JiT-L-32/pipeline.py @@ -0,0 +1,460 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.utils.torch_utils import randn_tensor + + +RECOMMENDED_NOISE_BY_SIZE = { + 256: 1.0, + 512: 2.0, +} + + +class JiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation using JiT (Just image Transformer). + + Parameters: + transformer ([`JiTTransformer2DModel`]): + A class-conditioned `JiTTransformer2DModel` to denoise the images. + scheduler ([`JiTScheduler`]): + Manual JiT flow-matching scheduler (linear `t in [0, 1]`, Heun or Euler). + id2label (`dict[int, str]`, *optional*): + ImageNet class id to English label mapping. Values may contain comma-separated synonyms. + id2label_cn (`dict[int, str]`, *optional*): + ImageNet class id to Chinese label mapping. Values may contain comma-separated synonyms. + """ + + model_cpu_offload_seq = "transformer" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs): + """Load a self-contained variant folder locally or from the Hub. + + Examples: + JiTPipeline.from_pretrained(".") + JiTPipeline.from_pretrained("./JiT-H-32") + DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True) + """ + repo_root = Path(__file__).resolve().parent + + if pretrained_model_name_or_path in (None, "", "."): + variant = repo_root + elif ( + isinstance(pretrained_model_name_or_path, str) + and "/" in pretrained_model_name_or_path + and not Path(pretrained_model_name_or_path).exists() + ): + from huggingface_hub import snapshot_download + + hub_kwargs = dict(kwargs.pop("hub_kwargs", {})) + if subfolder: + hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**", "labels/**"]) + cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs) + variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir) + else: + variant = Path(pretrained_model_name_or_path) + if not variant.is_absolute(): + candidate = (Path.cwd() / variant).resolve() + variant = candidate if candidate.exists() else (repo_root / variant).resolve() + if subfolder: + variant = variant / subfolder + + model_kwargs = dict(kwargs) + inserted: List[str] = [] + + def _load_component(folder: str, module_name: str, class_name: str): + comp_dir = variant / folder + module_path = comp_dir / f"{module_name}.py" + has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists() + if not module_path.exists() or not has_weights: + return None + + comp_path = str(comp_dir) + if comp_path not in sys.path: + sys.path.insert(0, comp_path) + inserted.append(comp_path) + + module = importlib.import_module(module_name) + component_cls = getattr(module, class_name) + return component_cls.from_pretrained(str(comp_dir), **model_kwargs) + + try: + transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel") + scheduler = _load_component("scheduler", "scheduling_jit", "JiTScheduler") + + if transformer is None: + raise ValueError(f"No loadable transformer found under {variant}") + + variant_path = str(variant) + id2label, id2label_cn = cls._load_labels_for_variant(variant_path) + + pipe = cls( + transformer=transformer, + scheduler=scheduler, + id2label=id2label, + id2label_cn=id2label_cn, + ) + if variant_path and hasattr(pipe, "register_to_config"): + pipe.register_to_config(_name_or_path=variant_path) + return pipe + finally: + for comp_path in inserted: + if comp_path in sys.path: + sys.path.remove(comp_path) + + def __init__( + self, + transformer, + scheduler, + id2label: Optional[Dict[int, str]] = None, + id2label_cn: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, scheduler=scheduler) + + self._id2label = id2label or {} + self._id2label_cn = id2label_cn or {} + self.labels = self._build_label2id(self._id2label) + self.labels_cn = self._build_label2id(self._id2label_cn) + + def _ensure_labels_loaded(self) -> None: + if self._id2label or self._id2label_cn: + return + loaded_en, loaded_cn = self._load_labels_for_variant(getattr(self.config, "_name_or_path", None)) + if loaded_en: + self._id2label = loaded_en + self.labels = self._build_label2id(self._id2label) + if loaded_cn: + self._id2label_cn = loaded_cn + self.labels_cn = self._build_label2id(self._id2label_cn) + + @staticmethod + def _labels_dir_for_variant(variant_path: Optional[str]) -> Optional[Path]: + if not variant_path: + return None + variant_dir = Path(variant_path).resolve() + labels_dir = variant_dir.parent / "labels" + return labels_dir if labels_dir.is_dir() else None + + @staticmethod + def _read_id2label(labels_dir: Path, lang: str = "en") -> Dict[int, str]: + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = labels_dir / filename + if not path.exists(): + raise FileNotFoundError(path) + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + @classmethod + def _load_labels_for_variant( + cls, + variant_path: Optional[str], + ) -> Tuple[Optional[Dict[int, str]], Optional[Dict[int, str]]]: + labels_dir = cls._labels_dir_for_variant(variant_path) + if labels_dir is None: + return None, None + try: + return cls._read_id2label(labels_dir, "en"), cls._read_id2label(labels_dir, "cn") + except FileNotFoundError: + return None, None + + @staticmethod + def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]: + label2id: Dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + label2id[synonym] = int(class_id) + return dict(sorted(label2id.items())) + + @property + def id2label(self) -> Dict[int, str]: + """ImageNet class id to English label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label + + @property + def id2label_cn(self) -> Dict[int, str]: + """ImageNet class id to Chinese label string (comma-separated synonyms).""" + self._ensure_labels_loaded() + return self._id2label_cn + + def get_label_ids(self, label: Union[str, List[str]], lang: str = "en") -> List[int]: + r""" + Map ImageNet label strings to class ids. + + Args: + label (`str` or `list[str]`): + One or more label strings. Each string must match a synonym in `id2label` (English) + or `id2label_cn` (Chinese). + lang (`str`, *optional*, defaults to `"en"`): + `"en"` uses English synonyms; `"cn"` uses Chinese synonyms. + """ + if lang not in ("en", "cn"): + raise ValueError(f"`lang` must be 'en' or 'cn', got {lang!r}.") + + self._ensure_labels_loaded() + label2id = self.labels if lang == "en" else self.labels_cn + if not label2id: + raise ValueError( + f"No {lang} labels loaded. Ensure `labels/id2label_{lang}.json` exists next to the variant folder." + ) + + if isinstance(label, str): + label = [label] + + missing = [item for item in label if item not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. Example valid labels: {preview}, ..." + ) + return [label2id[item] for item in label] + + def _normalize_class_labels( + self, + class_labels: Union[int, str, List[Union[int, str]]], + ) -> List[int]: + if isinstance(class_labels, int): + return [class_labels] + + if isinstance(class_labels, str): + return self.get_label_ids(class_labels) + + if class_labels and isinstance(class_labels[0], str): + self._ensure_labels_loaded() + if all(label in self.labels for label in class_labels): + return self.get_label_ids(class_labels, lang="en") + if all(label in self.labels_cn for label in class_labels): + return self.get_label_ids(class_labels, lang="cn") + raise ValueError( + "Could not resolve string `class_labels`. Use English synonyms from `pipe.labels` " + "or Chinese synonyms from `pipe.labels_cn`." + ) + + return list(class_labels) + + def _predict_velocity( + self, + z_value: torch.Tensor, + t: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + ) -> torch.Tensor: + t = torch.as_tensor(t, device=z_value.device, dtype=z_value.dtype) + if do_classifier_free_guidance: + z_in = torch.cat([z_value, z_value], dim=0) + labels = torch.cat([class_labels, class_null], dim=0) + else: + z_in = z_value + labels = class_labels + + t_batch = t.flatten().expand(z_in.shape[0]) + x_pred = self.transformer(z_in, timestep=t_batch, class_labels=labels).sample + v = self.scheduler.velocity_from_prediction(z_in, x_pred, t) + + if not do_classifier_free_guidance: + return v + + v_cond, v_uncond = v.chunk(2, dim=0) + interval_mask = t < guidance_interval_max + if guidance_interval_min != 0.0: + interval_mask = interval_mask & (t > guidance_interval_min) + scale = torch.where( + interval_mask, + torch.tensor(guidance_scale, device=z_value.device, dtype=z_value.dtype), + torch.tensor(1.0, device=z_value.device, dtype=z_value.dtype), + ) + return v_uncond + scale * (v_cond - v_uncond) + + def _run_sampler( + self, + latents: torch.Tensor, + class_labels: torch.Tensor, + class_null: torch.Tensor, + num_inference_steps: int, + do_classifier_free_guidance: bool, + guidance_scale: float, + guidance_interval_min: float, + guidance_interval_max: float, + sampling_method: str, + ) -> torch.Tensor: + device = latents.device + self.scheduler.set_timesteps(num_inference_steps, device=device, solver=sampling_method) + timesteps = self.scheduler.timesteps + + for i in self.progress_bar(range(num_inference_steps - 1)): + t = timesteps[i] + t_next = timesteps[i + 1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + + if sampling_method == "heun": + latents_euler = latents + (t_next - t) * v + v_next = self._predict_velocity( + latents_euler, + t_next, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + latents = self.scheduler.step(v, t, latents, model_output_next=v_next).prev_sample + else: + latents = self.scheduler.step(v, t, latents).prev_sample + + t = timesteps[-2] + t_next = timesteps[-1] + v = self._predict_velocity( + latents, + t, + class_labels, + class_null, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + ) + return latents + (t_next - t) * v + + @torch.inference_mode() + def __call__( + self, + class_labels: Union[int, str, List[Union[int, str]]], + guidance_scale: Optional[float] = None, + guidance_interval_min: float = 0.1, + guidance_interval_max: float = 1.0, + noise_scale: Optional[float] = None, + t_eps: Optional[float] = None, + sampling_method: Optional[str] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Generate class-conditional images. + + Args: + class_labels (`int`, `str`, `list[int]`, or `list[str]`): + ImageNet class indices or human-readable label strings (English or Chinese). + guidance_scale (`float`, *optional*): + Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`. + guidance_interval_min (`float`, defaults to `0.1`): + Lower bound of the CFG interval in flow time `t in [0, 1]`. + guidance_interval_max (`float`, defaults to `1.0`): + Upper bound of the CFG interval in flow time. + noise_scale (`float`, *optional*): + Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default). + t_eps (`float`, *optional*): + Epsilon clamp for the `1 - t` denominator (scheduler config by default). + sampling_method (`str`, *optional*): + `"heun"` or `"euler"`. Defaults to the scheduler config (`heun`). + generator (`torch.Generator`, *optional*): + RNG for reproducibility. + num_inference_steps (`int`, defaults to `50`): + Number of solver steps (at least 2). + output_type (`str`, *optional*, defaults to `"pil"`): + `"pil"`, `"np"`, or `"pt"`. + return_dict (`bool`, *optional*, defaults to `True`): + Return [`ImagePipelineOutput`] if True. + """ + solver = sampling_method or self.scheduler.config.solver + if solver not in {"heun", "euler"}: + raise ValueError("sampling_method must be one of: 'heun', 'euler'.") + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + if t_eps is not None: + self.scheduler.register_to_config(t_eps=t_eps) + + class_label_ids = self._normalize_class_labels(class_labels) + do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0 + + batch_size = len(class_label_ids) + image_size = int(self.transformer.config.sample_size) + channels = int(self.transformer.config.in_channels) + null_class_val = int(self.transformer.config.num_classes) + + if guidance_scale is None: + guidance_scale = 1.0 + if noise_scale is None: + noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(image_size, 1.0) + + latents = ( + randn_tensor( + shape=(batch_size, channels, image_size, image_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + * noise_scale + ) + + class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1) + class_labels_t = class_labels_t.clamp(0, null_class_val - 1) + class_null = torch.full_like(class_labels_t, null_class_val) + + latents = self._run_sampler( + latents, + class_labels_t, + class_null, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + guidance_interval_min, + guidance_interval_max, + solver, + ) + + images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu() + if output_type == "pt": + images = images_pt + elif output_type == "np": + images = images_pt.permute(0, 2, 3, 1).numpy() + else: + images = self.numpy_to_pil(images_pt.permute(0, 2, 3, 1).numpy()) + + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + return ImagePipelineOutput(images=images) diff --git a/JiT-L-32/scheduler/scheduler_config.json b/JiT-L-32/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bb871ad8071d8be4699f5246288de0a17963a5c4 --- /dev/null +++ b/JiT-L-32/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "JiTScheduler", + "_diffusers_version": "0.36.0", + "num_train_timesteps": 1000, + "t_eps": 0.05, + "solver": "heun" +} diff --git a/JiT-L-32/scheduler/scheduling_jit.py b/JiT-L-32/scheduler/scheduling_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c55890f3446c190ca847f204264b4b8cbbbbb --- /dev/null +++ b/JiT-L-32/scheduler/scheduling_jit.py @@ -0,0 +1,161 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +class JiTSchedulerOutput(BaseOutput): + """ + Output class for the JiT scheduler's `step` function. + + Args: + prev_sample (`torch.Tensor`): + Updated sample after one solver step along the JiT flow-time grid. + """ + + prev_sample: torch.Tensor + + +class JiTScheduler(SchedulerMixin, ConfigMixin): + """ + Manual flow-matching scheduler for JiT checkpoints. + + Uses a linear flow-time grid `t in [0, 1]` (increasing), matching the official JiT + sampler. Velocity is `v = (x_pred - z) / (1 - t)`; integration is explicit Euler or + Heun along that grid. + """ + + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + t_eps: float = 5e-2, + solver: str = "heun", + ): + if solver not in {"heun", "euler"}: + raise ValueError("solver must be one of: 'heun', 'euler'.") + self.timesteps: Optional[torch.Tensor] = None + self.sigmas: Optional[List[float]] = None + self.num_inference_steps: Optional[int] = None + self._step_index: Optional[int] = None + + @property + def init_noise_sigma(self) -> float: + return 1.0 + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device, None] = None, + solver: Optional[str] = None, + ) -> None: + if num_inference_steps < 2: + raise ValueError("num_inference_steps must be >= 2.") + + self.num_inference_steps = num_inference_steps + self.timesteps = torch.linspace( + 0.0, + 1.0, + num_inference_steps + 1, + device=device, + dtype=torch.float32, + ) + sigma_grid = torch.linspace(0.0, 1.0, num_inference_steps, device=device, dtype=torch.float32) + self.sigmas = (1.0 - sigma_grid).tolist() + self._step_index = 0 + if solver is not None: + self.register_to_config(solver=solver) + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + del timestep + return sample + + def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int: + if self._step_index is not None: + return self._step_index + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + if timestep is None: + return 0 + t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0]) + matches = (self.timesteps - t_value).abs() < 1e-6 + if matches.any(): + return int(matches.nonzero(as_tuple=False)[0].item()) + return 0 + + def step( + self, + model_output: torch.Tensor, + timestep: Union[float, torch.Tensor, None], + sample: torch.Tensor, + model_output_next: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[JiTSchedulerOutput, Tuple[torch.Tensor]]: + """ + Integrate one step on the linear `t` grid. + + Args: + model_output (`torch.Tensor`): + Velocity `v = (x_pred - z) / (1 - t)` at the current time. + timestep (`float` or `torch.Tensor`, *optional*): + Current flow time `t`. When omitted, uses the internal step index. + sample (`torch.Tensor`): + Current noisy latent `z`. + model_output_next (`torch.Tensor`, *optional*): + Velocity at `t_next` (required for Heun intermediate steps). + """ + if self.timesteps is None: + raise ValueError("Call `set_timesteps` before `step`.") + + step_index = self._resolve_step_index(timestep) + if step_index >= len(self.timesteps) - 1: + raise ValueError("Scheduler has already reached the final timestep.") + + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + dt = t_next - t + + if self.config.solver == "heun" and model_output_next is not None: + prev_sample = sample + dt * 0.5 * (model_output + model_output_next) + else: + prev_sample = sample + dt * model_output + + self._step_index = step_index + 1 + + if not return_dict: + return (prev_sample,) + return JiTSchedulerOutput(prev_sample=prev_sample) + + def velocity_from_prediction( + self, + sample: torch.Tensor, + x_pred: torch.Tensor, + timestep: Union[float, torch.Tensor], + ) -> torch.Tensor: + """Compute JiT velocity `v = (x_pred - z) / (1 - t)` with denominator clamp.""" + t = torch.as_tensor(timestep, device=sample.device, dtype=sample.dtype) + while t.ndim < sample.ndim: + t = t.unsqueeze(-1) + denom = (1.0 - t).clamp_min(self.config.t_eps) + return (x_pred - sample) / denom diff --git a/JiT-L-32/transformer/config.json b/JiT-L-32/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..3e8b6cf5b2811fc93bc581a5b5c07f2dc53033b8 --- /dev/null +++ b/JiT-L-32/transformer/config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "JiTTransformer2DModel", + "_diffusers_version": "0.36.0", + "attention_dropout": 0.0, + "bottleneck_dim": 128, + "dropout": 0.0, + "hidden_size": 1024, + "in_channels": 3, + "in_context_len": 32, + "in_context_start": 8, + "mlp_ratio": 4.0, + "norm_eps": 1e-06, + "num_attention_heads": 16, + "num_classes": 1000, + "num_layers": 24, + "patch_size": 32, + "sample_size": 512 +} diff --git a/JiT-L-32/transformer/diffusion_pytorch_model.safetensors b/JiT-L-32/transformer/diffusion_pytorch_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..6f2359fc7995802c7d967cf572380f4bf3b75b5c --- /dev/null +++ b/JiT-L-32/transformer/diffusion_pytorch_model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:121d3917ab50ad034295646734eb9b898167f19419dd65d22946f38c7d183266 +size 1847219704 diff --git a/JiT-L-32/transformer/jit_transformer_2d.py b/JiT-L-32/transformer/jit_transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3af0b8632931f4d42d78f8f9ced62d868e070e43 --- /dev/null +++ b/JiT-L-32/transformer/jit_transformer_2d.py @@ -0,0 +1,500 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + if len(shape_lens) != 1: + raise ValueError("tensors must all have the same number of dimensions") + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + + if not all(len(set(t[1])) <= 2 for t in expandable_dims): + raise ValueError("invalid dimensions for broadcastable concatenation") + + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = x.view(*x.shape[:-1], x.shape[-1] // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(*x.shape[:-2], -1) + + +class JiTRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len=16, + ft_seq_len=None, + custom_freqs=None, + theta=10000, + num_cls_token=0, + ): + super().__init__() + if custom_freqs is not None: + freqs = custom_freqs + else: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = freqs.repeat_interleave(2, dim=-1) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + if num_cls_token > 0: + freqs_flat = freqs.view(-1, freqs.shape[-1]) # [N_img, D] + cos_img = freqs_flat.cos() + sin_img = freqs_flat.sin() + + # prepend in-context cls token + _, D = cos_img.shape + cos_pad = torch.ones(num_cls_token, D, dtype=cos_img.dtype) + sin_pad = torch.zeros(num_cls_token, D, dtype=sin_img.dtype) + + self.register_buffer("freqs_cos", torch.cat([cos_pad, cos_img], dim=0), persistent=False) + self.register_buffer("freqs_sin", torch.cat([sin_pad, sin_img], dim=0), persistent=False) + else: + self.register_buffer("freqs_cos", freqs.cos().view(-1, freqs.shape[-1]), persistent=False) + self.register_buffer("freqs_sin", freqs.sin().view(-1, freqs.shape[-1]), persistent=False) + + def forward(self, t): + # Applied on (batch, seq_len, heads, head_dim) tensors from attention. + seq_len = t.shape[1] + freqs_cos = self.freqs_cos[:seq_len].to(t.dtype) + freqs_sin = self.freqs_sin[:seq_len].to(t.dtype) + + return t * freqs_cos[:, None, :] + rotate_half(t) * freqs_sin[:, None, :] + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class JiTPatchEmbed(nn.Module): + """Image to Patch Embedding with Bottleneck""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, pca_dim=768, embed_dim=768, bias=True): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias) + + def forward(self, x): + x = self.proj2(self.proj1(x)).flatten(2).transpose(1, 2) + return x + + +class JiTTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=None): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if dtype is not None: + t_freq = t_freq.to(dtype=dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class JiTLabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. + """ + + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes + 1, hidden_size) + self.num_classes = num_classes + + def forward(self, labels): + embeddings = self.embedding_table(labels) + return embeddings + + +class JiTAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.q_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.k_norm = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rope=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = self.q_norm(q) + k = self.k_norm(k) + + if rope is not None: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + q = rope(q) + k = rope(k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + dropout_p = self.attn_drop if self.training else 0.0 + x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JiTSwiGLUFFN(nn.Module): + def __init__(self, dim: int, hidden_dim: int, drop=0.0, bias=True) -> None: + super().__init__() + hidden_dim = int(hidden_dim * 2 / 3) + self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=bias) + self.w3 = nn.Linear(hidden_dim, dim, bias=bias) + self.ffn_dropout = nn.Dropout(drop) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(self.ffn_dropout(hidden)) + + +class JiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0, eps=1e-6): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=eps) + self.attn = JiTAttention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=attn_drop, + proj_drop=proj_drop, + eps=eps, + ) + self.norm2 = RMSNorm(hidden_size, eps=eps) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = JiTSwiGLUFFN(hidden_size, mlp_hidden_dim, drop=proj_drop) + + self.act = nn.SiLU() + self.adaLN_modulation = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward(self, x, c, feat_rope=None): + # Apply activation + c = self.act(c) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention block + norm_x = self.norm1(x) + modulated_x = modulate(norm_x, shift_msa, scale_msa) + attn_out = self.attn(modulated_x, rope=feat_rope) + x = x + gate_msa.unsqueeze(1) * attn_out + + # MLP block + norm_x = self.norm2(x) + modulated_x = modulate(norm_x, shift_mlp, scale_mlp) + mlp_out = self.mlp(modulated_x) + x = x + gate_mlp.unsqueeze(1) * mlp_out + + return x + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + if embed_dim % 2 != 0: + raise ValueError(f"embed_dim must be divisible by 2, but got {embed_dim}") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +class JiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer for pixel-space class-conditional generation with JiT + ([Back to Basics: Let Denoising Generative Models Denoise](https://arxiv.org/abs/2511.13720)). + + Parameters: + sample_size (`int`, defaults to `256`): + Input image resolution (height and width). + patch_size (`int`, defaults to `16`): + Patch size for the bottleneck patch embedder. + in_channels (`int`, defaults to `3`): + Number of input image channels. + hidden_size (`int`, defaults to `768`): + Transformer hidden dimension. + num_layers (`int`, defaults to `12`): + Number of JiT transformer blocks. + num_attention_heads (`int`, defaults to `12`): + Number of attention heads per block. + mlp_ratio (`float`, defaults to `4.0`): + MLP hidden dimension multiplier. + attention_dropout (`float`, defaults to `0.0`): + Attention dropout in the middle quarter of blocks. + dropout (`float`, defaults to `0.0`): + Projection dropout in the middle quarter of blocks. + num_classes (`int`, defaults to `1000`): + Number of class labels (null label uses index `num_classes` for CFG). + bottleneck_dim (`int`, defaults to `128`): + PCA bottleneck dimension in the patch embedder. + in_context_len (`int`, defaults to `32`): + Number of in-context class tokens prepended mid-network. + in_context_start (`int`, defaults to `4`): + Block index at which in-context tokens are inserted. + norm_eps (`float`, defaults to `1e-6`): + Epsilon for RMSNorm layers. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 256, + patch_size: int = 16, + in_channels: int = 3, + hidden_size: int = 768, + num_layers: int = 12, + num_attention_heads: int = 12, + mlp_ratio: float = 4.0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + num_classes: int = 1000, + bottleneck_dim: int = 128, + in_context_len: int = 32, + in_context_start: int = 4, + norm_eps: float = 1e-6, + ): + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.in_context_len = in_context_len + self.in_context_start = in_context_start + self.norm_eps = norm_eps + self.gradient_checkpointing = False + + # Time and Class Embedding + self.t_embedder = JiTTimestepEmbedder(hidden_size) + self.y_embedder = JiTLabelEmbedder(num_classes, hidden_size) + + # Patch Embedding + self.x_embedder = JiTPatchEmbed( + img_size=sample_size, + patch_size=patch_size, + in_chans=in_channels, + pca_dim=bottleneck_dim, + embed_dim=hidden_size, + bias=True, + ) + + # Positional Embedding (Fixed Sin-Cos) + num_patches = self.x_embedder.num_patches + pos_embed = get_2d_sincos_pos_embed(hidden_size, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + # In-context Embedding + if self.in_context_len > 0: + self.in_context_posemb = nn.Parameter(torch.zeros(1, self.in_context_len, hidden_size)) + + # RoPE + half_head_dim = hidden_size // num_attention_heads // 2 + hw_seq_len = sample_size // patch_size + self.feat_rope = JiTRotaryEmbedding(dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=0) + self.feat_rope_incontext = JiTRotaryEmbedding( + dim=half_head_dim, pt_seq_len=hw_seq_len, num_cls_token=self.in_context_len + ) + + # Blocks + self.blocks = nn.ModuleList( + [ + JiTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop=attention_dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + proj_drop=dropout if (num_layers // 4 * 3 > i >= num_layers // 4) else 0.0, + eps=norm_eps, + ) + for i in range(num_layers) + ] + ) + + # Final Layer + self.norm_final = RMSNorm(hidden_size, eps=norm_eps) + self.linear_final = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + self.act_final = nn.SiLU() + self.adaLN_modulation_final = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + class_labels: torch.LongTensor, + return_dict: bool = True, + ): + + t_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + y_emb = self.y_embedder(class_labels) + + # Ensure embeddings match hidden_states dtype + y_emb = y_emb.to(dtype=hidden_states.dtype) + + c = t_emb + y_emb + + # Patch Embed + x = self.x_embedder(hidden_states) + x = x + self.pos_embed.to(x.dtype) + + # Blocks + for i, block in enumerate(self.blocks): + if self.in_context_len > 0 and i == self.in_context_start: + in_context_tokens = y_emb.unsqueeze(1).repeat(1, self.in_context_len, 1) + in_context_tokens = in_context_tokens + self.in_context_posemb.to(in_context_tokens.dtype) + x = torch.cat([in_context_tokens, x], dim=1) + + rope = self.feat_rope if i < self.in_context_start else self.feat_rope_incontext + + if self.training and self.gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + block, + x, + c, + rope, + use_reentrant=False, + ) + else: + x = block(x, c, feat_rope=rope) + + # Slice off in-context tokens + if self.in_context_len > 0: + x = x[:, self.in_context_len :] + + # Final Layer + c = self.act_final(c) + shift, scale = self.adaLN_modulation_final(c).chunk(2, dim=1) + + x = modulate(self.norm_final(x), shift, scale) + x = self.linear_final(x) + + # Unpatchify + h = w = int(x.shape[1] ** 0.5) + x = x.reshape(shape=(x.shape[0], h, w, self.patch_size, self.patch_size, self.out_channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + output = x.reshape(shape=(x.shape[0], self.out_channels, h * self.patch_size, w * self.patch_size)) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/README.md b/README.md index 8bea21934e1321d732e79d6706a4138f0504fb9e..722c43266b7f987979f51aec70ffabb1881ffada 100644 --- a/README.md +++ b/README.md @@ -14,77 +14,67 @@ language: - en --- -# JiT-H/32 (Diffusers) +# JiT-diffusers -This repository is self-contained: model weights and a custom `diffusers` pipeline (`JiTPipeline`) are both included, so no external code repo is required. +Native diffusers implementation of **JiT** (Just image Transformer). Each variant folder is self-contained: -## Available Checkpoints (All 6) +- `pipeline.py` — `JiTPipeline` +- `scheduler/scheduling_jit.py` — `JiTScheduler` (linear `t in [0, 1]`, Heun/Euler) +- `transformer/jit_transformer_2d.py` — `JiTTransformer2DModel` -The JiT paper reports six ImageNet checkpoints across 256 and 512 resolutions. Use the following relative paths with `JiTPipeline.from_pretrained(...)`. +Shared ImageNet-1k labels live in [`labels/`](labels/) at the repo root (not duplicated per variant). -| Checkpoint | Relative path | Resolution | Pre-trained dataset | Recommended CFG | Recommended interval | Recommended noise_scale | FID-50K | -|---|---|---|---|---:|---|---:|---:| -| JiT-B/16 | `./JiT-B-16` | 256x256 | ImageNet 256x256 | 3.0 | `[0.1, 1.0]` | 1.0 | 3.66 | -| JiT-L/16 | `./JiT-L-16` | 256x256 | ImageNet 256x256 | 2.4 | `[0.1, 1.0]` | 1.0 | 2.36 | -| JiT-H/16 | `./JiT-H-16` | 256x256 | ImageNet 256x256 | 2.2 | `[0.1, 1.0]` | 1.0 | 1.86 | -| JiT-B/32 | `./JiT-B-32` | 512x512 | ImageNet 512x512 | 3.0 | `[0.1, 1.0]` | 2.0 | 4.02 | -| JiT-L/32 | `./JiT-L-32` | 512x512 | ImageNet 512x512 | 2.5 | `[0.1, 1.0]` | 2.0 | 2.53 | -| JiT-H/32 | `./JiT-H-32` | 512x512 | ImageNet 512x512 | 2.3 | `[0.1, 1.0]` | 2.0 | 1.94 | +No separate `jit_diffusers` package; only PyPI `diffusers` plus local custom code in the variant directory. -Source: [Back to Basics: Let Denoising Generative Models Denoise (arXiv:2511.13720)](https://arxiv.org/html/2511.13720). +## Available checkpoints -## Demo Image +| Checkpoint | Path | Resolution | Recommended CFG | +|---|---|---|---| +| JiT-B/16 | `./JiT-B-16` | 256×256 | 3.0 | +| JiT-L/16 | `./JiT-L-16` | 256×256 | 2.4 | +| JiT-H/16 | `./JiT-H-16` | 256×256 | 2.2 | +| JiT-B/32 | `./JiT-B-32` | 512×512 | 3.0 | +| JiT-L/32 | `./JiT-L-32` | 512×512 | 2.5 | +| JiT-H/32 | `./JiT-H-32` | 512×512 | 2.3 | -![JiT-H/32 test inference](demo_images/jit_h32_test_inference.png) +## ImageNet class labels -## One-Stop Diffusers Inference +| File | Direction | Format | +|---|---|---| +| `labels/id2label_en.json` | id → English | comma-separated synonyms, e.g. `"207": "golden retriever"` | +| `labels/id2label_cn.json` | id → Chinese | comma-separated synonyms, e.g. `"207": "金毛猎犬"` | + +- `pipe.id2label` / `pipe.id2label_cn` — inspect id → label correspondence +- `pipe.labels` / `pipe.labels_cn` — reverse maps (synonym → id), sorted for browsing +- `pipe.get_label_ids("golden retriever")` or `pipe.get_label_ids("金毛猎犬", lang="cn")` +- `pipe(class_labels="golden retriever", ...)` — string labels resolved automatically + +## Inference ```python -from pathlib import Path -import sys +from diffusers import DiffusionPipeline import torch -repo_dir = Path(".").resolve() -sys.path.insert(0, str(repo_dir)) -from jit_diffusers import JiTPipeline +pipe = DiffusionPipeline.from_pretrained( + "./JiT-H-32", + trust_remote_code=True, +) +pipe.to("cuda") +pipe.transformer.to(dtype=torch.bfloat16) -device = "cuda" if torch.cuda.is_available() else "cpu" -pipe = JiTPipeline.from_pretrained("./JiT-H-32").to(device) -pipe.transformer = pipe.transformer.to(device=device, dtype=torch.bfloat16 if device == "cuda" else torch.float32) -pipe.transformer.eval() +# Numeric or human-readable labels +print(pipe.id2label[207]) +print(pipe.get_label_ids("golden retriever")) -generator = torch.Generator(device=device).manual_seed(42) -output = pipe( - class_labels=[207], +generator = torch.Generator(device="cuda").manual_seed(42) +images = pipe( + class_labels="golden retriever", num_inference_steps=50, guidance_scale=2.3, - guidance_interval_min=0.1, - guidance_interval_max=1.0, - noise_scale=2.0, - t_eps=5e-2, sampling_method="heun", generator=generator, - output_type="pil", -) -image = output.images[0] -output_path = Path("./demo_images/jit_h32_test_inference.png") -output_path.parent.mkdir(parents=True, exist_ok=True) -image.save(output_path) -print(f"Saved image to: {output_path}") +).images +images[0].save("output.png") ``` -## Ready-to-Run Commands (All 6 Checkpoints) - -Run these from this repository root (`models/BiliSakura/JiT-diffusers`). - -```bash -# 256x256 checkpoints -python run_jit_diffusers_inference.py --model_path ./JiT-B-16 --output_path ./demo_images/jit_b16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun -python run_jit_diffusers_inference.py --model_path ./JiT-L-16 --output_path ./demo_images/jit_l16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.4 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun -python run_jit_diffusers_inference.py --model_path ./JiT-H-16 --output_path ./demo_images/jit_h16_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.2 --interval_min 0.1 --interval_max 1.0 --noise_scale 1.0 --t_eps 5e-2 --solver heun - -# 512x512 checkpoints -python run_jit_diffusers_inference.py --model_path ./JiT-B-32 --output_path ./demo_images/jit_b32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 3.0 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun -python run_jit_diffusers_inference.py --model_path ./JiT-L-32 --output_path ./demo_images/jit_l32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.5 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun -python run_jit_diffusers_inference.py --model_path ./JiT-H-32 --output_path ./demo_images/jit_h32_test_inference.png --class_label 207 --seed 42 --steps 50 --cfg 2.3 --interval_min 0.1 --interval_max 1.0 --noise_scale 2.0 --t_eps 5e-2 --solver heun -``` \ No newline at end of file +Load a **variant subfolder** (e.g. `./JiT-H-32`), not the repo root. diff --git a/demo.png b/demo.png index 2ac9cc223994e0b3a7f0a785e40728e39f08f904..1d7d706c338067fcaea523a4d5dd835d6a275f97 100644 --- a/demo.png +++ b/demo.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899 -size 490449 +oid sha256:f5fdbd0300f895de7642229d1294aff74facd75c0bb4c4a01efa8c75b14b6fc4 +size 470060 diff --git a/demo_images/jit_h32_final_test.png b/demo_images/jit_h32_final_test.png new file mode 100644 index 0000000000000000000000000000000000000000..fd6bd5fccdb1b65beb436907cb7d7cc9d22d207d --- /dev/null +++ b/demo_images/jit_h32_final_test.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc6804e8a82ad4873a6e9c9e2cf31a7ab901516184cca29955d12b59a45a8920 +size 470455 diff --git a/demo_images/jit_h32_test_inference.png b/demo_images/jit_h32_test_inference.png index 2ac9cc223994e0b3a7f0a785e40728e39f08f904..6c43bd72a0f135c595236086e1ffd9a54821d272 100644 --- a/demo_images/jit_h32_test_inference.png +++ b/demo_images/jit_h32_test_inference.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d595ae2a4d665119949ee1c3930fd7a24befd51d4d4b1932a1a4c7e9e180f899 -size 490449 +oid sha256:c3a6657c5ac1b7e50dec1dfff0fc02759b1b0e1d5cff75cb3db45af1391fef73 +size 470050 diff --git a/labels/__pycache__/imagenet_labels.cpython-312.pyc b/labels/__pycache__/imagenet_labels.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d526255261f1b4e1f0476d78d5da630531340a0 Binary files /dev/null and b/labels/__pycache__/imagenet_labels.cpython-312.pyc differ diff --git a/labels/id2label_cn.json b/labels/id2label_cn.json new file mode 100644 index 0000000000000000000000000000000000000000..dc1d2796598f87fb54d88c006953de2756069934 --- /dev/null +++ b/labels/id2label_cn.json @@ -0,0 +1,1002 @@ +{ + "0": "丁鲷", + "1": "金鱼", + "2": "大白鲨", + "3": "虎鲨", + "4": "锤头鲨", + "5": "电鳐", + "6": "黄貂鱼", + "7": "公鸡", + "8": "母鸡", + "9": "鸵鸟", + "10": "燕雀", + "11": "金翅雀", + "12": "家朱雀", + "13": "灯芯草雀", + "14": "靛蓝雀,靛蓝鸟", + "15": "蓝鹀", + "16": "夜莺", + "17": "松鸦", + "18": "喜鹊", + "19": "山雀", + "20": "河鸟", + "21": "鸢(猛禽)", + "22": "秃头鹰", + "23": "秃鹫", + "24": "大灰猫头鹰", + "25": "欧洲火蝾螈", + "26": "普通蝾螈", + "27": "水蜥", + "28": "斑点蝾螈", + "29": "蝾螈,泥狗", + "30": "牛蛙", + "31": "树蛙", + "32": "尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍", + "33": "红海龟", + "34": "皮革龟", + "35": "泥龟", + "36": "淡水龟", + "37": "箱龟", + "38": "带状壁虎", + "39": "普通鬣蜥", + "40": "美国变色龙", + "41": "鞭尾蜥蜴", + "42": "飞龙科蜥蜴", + "43": "褶边蜥蜴", + "44": "鳄鱼蜥蜴", + "45": "毒蜥", + "46": "绿蜥蜴", + "47": "非洲变色龙", + "48": "科莫多蜥蜴", + "49": "非洲鳄,尼罗河鳄鱼", + "50": "美国鳄鱼,鳄鱼", + "51": "三角龙", + "52": "雷蛇,蠕虫蛇", + "53": "环蛇,环颈蛇", + "54": "希腊蛇", + "55": "绿蛇,草蛇", + "56": "国王蛇", + "57": "袜带蛇,草蛇", + "58": "水蛇", + "59": "藤蛇", + "60": "夜蛇", + "61": "大蟒蛇", + "62": "岩石蟒蛇,岩蛇,蟒蛇", + "63": "印度眼镜蛇", + "64": "绿曼巴", + "65": "海蛇", + "66": "角腹蛇", + "67": "菱纹响尾蛇", + "68": "角响尾蛇", + "69": "三叶虫", + "70": "盲蜘蛛", + "71": "蝎子", + "72": "黑金花园蜘蛛", + "73": "谷仓蜘蛛", + "74": "花园蜘蛛", + "75": "黑寡妇蜘蛛", + "76": "狼蛛", + "77": "狼蜘蛛,狩猎蜘蛛", + "78": "壁虱", + "79": "蜈蚣", + "80": "黑松鸡", + "81": "松鸡,雷鸟", + "82": "披肩鸡,披肩榛鸡", + "83": "草原鸡,草原松鸡", + "84": "孔雀", + "85": "鹌鹑", + "86": "鹧鸪", + "87": "非洲灰鹦鹉", + "88": "金刚鹦鹉", + "89": "硫冠鹦鹉", + "90": "短尾鹦鹉", + "91": "褐翅鸦鹃", + "92": "蜜蜂", + "93": "犀鸟", + "94": "蜂鸟", + "95": "鹟䴕", + "96": "犀鸟", + "97": "野鸭", + "98": "红胸秋沙鸭", + "99": "鹅", + "100": "黑天鹅", + "101": "大象", + "102": "针鼹鼠", + "103": "鸭嘴兽", + "104": "沙袋鼠", + "105": "考拉,考拉熊", + "106": "袋熊", + "107": "水母", + "108": "海葵", + "109": "脑珊瑚", + "110": "扁形虫扁虫", + "111": "线虫,蛔虫", + "112": "海螺", + "113": "蜗牛", + "114": "鼻涕虫", + "115": "海参", + "116": "石鳖", + "117": "鹦鹉螺", + "118": "珍宝蟹", + "119": "石蟹", + "120": "招潮蟹", + "121": "帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹", + "122": "美国龙虾,缅因州龙虾", + "123": "大螯虾", + "124": "小龙虾", + "125": "寄居蟹", + "126": "等足目动物(明虾和螃蟹近亲)", + "127": "白鹳", + "128": "黑鹳", + "129": "鹭", + "130": "火烈鸟", + "131": "小蓝鹭", + "132": "美国鹭,大白鹭", + "133": "麻鸦", + "134": "鹤", + "135": "秧鹤", + "136": "欧洲水鸡,紫水鸡", + "137": "沼泽泥母鸡,水母鸡", + "138": "鸨", + "139": "红翻石鹬", + "140": "红背鹬,黑腹滨鹬", + "141": "红脚鹬", + "142": "半蹼鹬", + "143": "蛎鹬", + "144": "鹈鹕", + "145": "国王企鹅", + "146": "信天翁,大海鸟", + "147": "灰鲸", + "148": "杀人鲸,逆戟鲸,虎鲸", + "149": "海牛", + "150": "海狮", + "151": "奇瓦瓦", + "152": "日本猎犬", + "153": "马尔济斯犬", + "154": "狮子狗", + "155": "西施犬", + "156": "布莱尼姆猎犬", + "157": "巴比狗", + "158": "玩具犬", + "159": "罗得西亚长背猎狗", + "160": "阿富汗猎犬", + "161": "猎犬", + "162": "比格犬,猎兔犬", + "163": "侦探犬", + "164": "蓝色快狗", + "165": "黑褐猎浣熊犬", + "166": "沃克猎犬", + "167": "英国猎狐犬", + "168": "美洲赤狗", + "169": "俄罗斯猎狼犬", + "170": "爱尔兰猎狼犬", + "171": "意大利灰狗", + "172": "惠比特犬", + "173": "依比沙猎犬", + "174": "挪威猎犬", + "175": "奥达猎犬,水獭猎犬", + "176": "沙克犬,瞪羚猎犬", + "177": "苏格兰猎鹿犬,猎鹿犬", + "178": "威玛猎犬", + "179": "斯塔福德郡牛头梗,斯塔福德郡斗牛梗", + "180": "美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗", + "181": "贝德灵顿梗", + "182": "边境梗", + "183": "凯丽蓝梗", + "184": "爱尔兰梗", + "185": "诺福克梗", + "186": "诺维奇梗", + "187": "约克郡梗", + "188": "刚毛猎狐梗", + "189": "莱克兰梗", + "190": "锡利哈姆梗", + "191": "艾尔谷犬", + "192": "凯恩梗", + "193": "澳大利亚梗", + "194": "丹迪丁蒙梗", + "195": "波士顿梗", + "196": "迷你雪纳瑞犬", + "197": "巨型雪纳瑞犬", + "198": "标准雪纳瑞犬", + "199": "苏格兰梗", + "200": "西藏梗,菊花狗", + "201": "丝毛梗", + "202": "软毛麦色梗", + "203": "西高地白梗", + "204": "拉萨阿普索犬", + "205": "平毛寻回犬", + "206": "卷毛寻回犬", + "207": "金毛猎犬", + "208": "拉布拉多猎犬", + "209": "乞沙比克猎犬", + "210": "德国短毛猎犬", + "211": "维兹拉犬", + "212": "英国谍犬", + "213": "爱尔兰雪达犬,红色猎犬", + "214": "戈登雪达犬", + "215": "布列塔尼犬猎犬", + "216": "黄毛,黄毛猎犬", + "217": "英国史宾格犬", + "218": "威尔士史宾格犬", + "219": "可卡犬,英国可卡犬", + "220": "萨塞克斯猎犬", + "221": "爱尔兰水猎犬", + "222": "哥威斯犬", + "223": "舒柏奇犬", + "224": "比利时牧羊犬", + "225": "马里努阿犬", + "226": "伯瑞犬", + "227": "凯尔皮犬", + "228": "匈牙利牧羊犬", + "229": "老英国牧羊犬", + "230": "喜乐蒂牧羊犬", + "231": "牧羊犬", + "232": "边境牧羊犬", + "233": "法兰德斯牧牛狗", + "234": "罗特韦尔犬", + "235": "德国牧羊犬,德国警犬,阿尔萨斯", + "236": "多伯曼犬,杜宾犬", + "237": "迷你杜宾犬", + "238": "大瑞士山地犬", + "239": "伯恩山犬", + "240": "Appenzeller狗", + "241": "EntleBucher狗", + "242": "拳师狗", + "243": "斗牛獒", + "244": "藏獒", + "245": "法国斗牛犬", + "246": "大丹犬", + "247": "圣伯纳德狗", + "248": "爱斯基摩犬,哈士奇", + "249": "雪橇犬,阿拉斯加爱斯基摩狗", + "250": "哈士奇", + "251": "达尔马提亚,教练车狗", + "252": "狮毛狗", + "253": "巴辛吉狗", + "254": "哈巴狗,狮子狗", + "255": "莱昂贝格狗", + "256": "纽芬兰岛狗", + "257": "大白熊犬", + "258": "萨摩耶犬", + "259": "博美犬", + "260": "松狮,松狮", + "261": "荷兰卷尾狮毛狗", + "262": "布鲁塞尔格林芬犬", + "263": "彭布洛克威尔士科基犬", + "264": "威尔士柯基犬", + "265": "玩具贵宾犬", + "266": "迷你贵宾犬", + "267": "标准贵宾犬", + "268": "墨西哥无毛犬", + "269": "灰狼", + "270": "白狼,北极狼", + "271": "红太狼,鬃狼,犬犬鲁弗斯", + "272": "狼,草原狼,刷狼,郊狼", + "273": "澳洲野狗,澳大利亚野犬", + "274": "豺", + "275": "非洲猎犬,土狼犬", + "276": "鬣狗", + "277": "红狐狸", + "278": "沙狐", + "279": "北极狐狸,白狐狸", + "280": "灰狐狸", + "281": "虎斑猫", + "282": "山猫,虎猫", + "283": "波斯猫", + "284": "暹罗暹罗猫,", + "285": "埃及猫", + "286": "美洲狮,美洲豹", + "287": "猞猁,山猫", + "288": "豹子", + "289": "雪豹", + "290": "美洲虎", + "291": "狮子", + "292": "老虎", + "293": "猎豹", + "294": "棕熊", + "295": "美洲黑熊", + "296": "冰熊,北极熊", + "297": "懒熊", + "298": "猫鼬", + "299": "猫鼬,海猫", + "300": "虎甲虫", + "301": "瓢虫", + "302": "土鳖虫", + "303": "天牛", + "304": "龟甲虫", + "305": "粪甲虫", + "306": "犀牛甲虫", + "307": "象甲", + "308": "苍蝇", + "309": "蜜蜂", + "310": "蚂蚁", + "311": "蚱蜢", + "312": "蟋蟀", + "313": "竹节虫", + "314": "蟑螂", + "315": "螳螂", + "316": "蝉", + "317": "叶蝉", + "318": "草蜻蛉", + "319": "蜻蜓", + "320": "豆娘,蜻蛉", + "321": "优红蛱蝶", + "322": "小环蝴蝶", + "323": "君主蝴蝶,大斑蝶", + "324": "菜粉蝶", + "325": "白蝴蝶", + "326": "灰蝶", + "327": "海星", + "328": "海胆", + "329": "海参,海黄瓜", + "330": "野兔", + "331": "兔", + "332": "安哥拉兔", + "333": "仓鼠", + "334": "刺猬,豪猪,", + "335": "黑松鼠", + "336": "土拨鼠", + "337": "海狸", + "338": "豚鼠,豚鼠", + "339": "栗色马", + "340": "斑马", + "341": "猪", + "342": "野猪", + "343": "疣猪", + "344": "河马", + "345": "牛", + "346": "水牛,亚洲水牛", + "347": "野牛", + "348": "公羊", + "349": "大角羊,洛矶山大角羊", + "350": "山羊", + "351": "狷羚", + "352": "黑斑羚", + "353": "瞪羚", + "354": "阿拉伯单峰骆驼,骆驼", + "355": "羊驼", + "356": "黄鼠狼", + "357": "水貂", + "358": "臭猫", + "359": "黑足鼬", + "360": "水獭", + "361": "臭鼬,木猫", + "362": "獾", + "363": "犰狳", + "364": "树懒", + "365": "猩猩,婆罗洲猩猩", + "366": "大猩猩", + "367": "黑猩猩", + "368": "长臂猿", + "369": "合趾猿长臂猿,合趾猿", + "370": "长尾猴", + "371": "赤猴", + "372": "狒狒", + "373": "恒河猴,猕猴", + "374": "白头叶猴", + "375": "疣猴", + "376": "长鼻猴", + "377": "狨(美洲产小型长尾猴)", + "378": "卷尾猴", + "379": "吼猴", + "380": "伶猴", + "381": "蜘蛛猴", + "382": "松鼠猴", + "383": "马达加斯加环尾狐猴,鼠狐猴", + "384": "大狐猴,马达加斯加大狐猴", + "385": "印度大象,亚洲象", + "386": "非洲象,非洲象", + "387": "小熊猫", + "388": "大熊猫", + "389": "杖鱼", + "390": "鳗鱼", + "391": "银鲑,银鲑鱼", + "392": "三色刺蝶鱼", + "393": "海葵鱼", + "394": "鲟鱼", + "395": "雀鳝", + "396": "狮子鱼", + "397": "河豚", + "398": "算盘", + "399": "长袍", + "400": "学位袍", + "401": "手风琴", + "402": "原声吉他", + "403": "航空母舰", + "404": "客机", + "405": "飞艇", + "406": "祭坛", + "407": "救护车", + "408": "水陆两用车", + "409": "模拟时钟", + "410": "蜂房", + "411": "围裙", + "412": "垃圾桶", + "413": "攻击步枪,枪", + "414": "背包", + "415": "面包店,面包铺,", + "416": "平衡木", + "417": "热气球", + "418": "圆珠笔", + "419": "创可贴", + "420": "班卓琴", + "421": "栏杆,楼梯扶手", + "422": "杠铃", + "423": "理发师的椅子", + "424": "理发店", + "425": "牲口棚", + "426": "晴雨表", + "427": "圆筒", + "428": "园地小车,手推车", + "429": "棒球", + "430": "篮球", + "431": "婴儿床", + "432": "巴松管,低音管", + "433": "游泳帽", + "434": "沐浴毛巾", + "435": "浴缸,澡盆", + "436": "沙滩车,旅行车", + "437": "灯塔", + "438": "高脚杯", + "439": "熊皮高帽", + "440": "啤酒瓶", + "441": "啤酒杯", + "442": "钟塔", + "443": "(小儿用的)围嘴", + "444": "串联自行车,", + "445": "比基尼", + "446": "装订册", + "447": "双筒望远镜", + "448": "鸟舍", + "449": "船库", + "450": "雪橇", + "451": "饰扣式领带", + "452": "阔边女帽", + "453": "书橱", + "454": "书店,书摊", + "455": "瓶盖", + "456": "弓箭", + "457": "蝴蝶结领结", + "458": "铜制牌位", + "459": "奶罩", + "460": "防波堤,海堤", + "461": "铠甲", + "462": "扫帚", + "463": "桶", + "464": "扣环", + "465": "防弹背心", + "466": "动车,子弹头列车", + "467": "肉铺,肉菜市场", + "468": "出租车", + "469": "大锅", + "470": "蜡烛", + "471": "大炮", + "472": "独木舟", + "473": "开瓶器,开罐器", + "474": "开衫", + "475": "车镜", + "476": "旋转木马", + "477": "木匠的工具包,工具包", + "478": "纸箱", + "479": "车轮", + "480": "取款机,自动取款机", + "481": "盒式录音带", + "482": "卡带播放器", + "483": "城堡", + "484": "双体船", + "485": "CD播放器", + "486": "大提琴", + "487": "移动电话,手机", + "488": "铁链", + "489": "围栏", + "490": "链甲", + "491": "电锯,油锯", + "492": "箱子", + "493": "衣柜,洗脸台", + "494": "编钟,钟,锣", + "495": "中国橱柜", + "496": "圣诞袜", + "497": "教堂,教堂建筑", + "498": "电影院,剧场", + "499": "切肉刀,菜刀", + "500": "悬崖屋", + "501": "斗篷", + "502": "木屐,木鞋", + "503": "鸡尾酒调酒器", + "504": "咖啡杯", + "505": "咖啡壶", + "506": "螺旋结构(楼梯)", + "507": "组合锁", + "508": "电脑键盘,键盘", + "509": "糖果,糖果店", + "510": "集装箱船", + "511": "敞篷车", + "512": "开瓶器,瓶螺杆", + "513": "短号,喇叭", + "514": "牛仔靴", + "515": "牛仔帽", + "516": "摇篮", + "517": "起重机", + "518": "头盔", + "519": "板条箱", + "520": "小儿床", + "521": "砂锅", + "522": "槌球", + "523": "拐杖", + "524": "胸甲", + "525": "大坝,堤防", + "526": "书桌", + "527": "台式电脑", + "528": "有线电话", + "529": "尿布湿", + "530": "数字时钟", + "531": "数字手表", + "532": "餐桌板", + "533": "抹布", + "534": "洗碗机,洗碟机", + "535": "盘式制动器", + "536": "码头,船坞,码头设施", + "537": "狗拉雪橇", + "538": "圆顶", + "539": "门垫,垫子", + "540": "钻井平台,海上钻井", + "541": "鼓,乐器,鼓膜", + "542": "鼓槌", + "543": "哑铃", + "544": "荷兰烤箱", + "545": "电风扇,鼓风机", + "546": "电吉他", + "547": "电力机车", + "548": "电视,电视柜", + "549": "信封", + "550": "浓缩咖啡机", + "551": "扑面粉", + "552": "女用长围巾", + "553": "文件,文件柜,档案柜", + "554": "消防船", + "555": "消防车", + "556": "火炉栏", + "557": "旗杆", + "558": "长笛", + "559": "折叠椅", + "560": "橄榄球头盔", + "561": "叉车", + "562": "喷泉", + "563": "钢笔", + "564": "有四根帷柱的床", + "565": "运货车厢", + "566": "圆号,喇叭", + "567": "煎锅", + "568": "裘皮大衣", + "569": "垃圾车", + "570": "防毒面具,呼吸器", + "571": "汽油泵", + "572": "高脚杯", + "573": "卡丁车", + "574": "高尔夫球", + "575": "高尔夫球车", + "576": "狭长小船", + "577": "锣", + "578": "礼服", + "579": "钢琴", + "580": "温室,苗圃", + "581": "散热器格栅", + "582": "杂货店,食品市场", + "583": "断头台", + "584": "小发夹", + "585": "头发喷雾", + "586": "半履带装甲车", + "587": "锤子", + "588": "大篮子", + "589": "手摇鼓风机,吹风机", + "590": "手提电脑", + "591": "手帕", + "592": "硬盘", + "593": "口琴,口风琴", + "594": "竖琴", + "595": "收割机", + "596": "斧头", + "597": "手枪皮套", + "598": "家庭影院", + "599": "蜂窝", + "600": "钩爪", + "601": "衬裙", + "602": "单杠", + "603": "马车", + "604": "沙漏", + "605": "手机,iPad", + "606": "熨斗", + "607": "南瓜灯笼", + "608": "牛仔裤,蓝色牛仔裤", + "609": "吉普车", + "610": "运动衫,T恤", + "611": "拼图", + "612": "人力车", + "613": "操纵杆", + "614": "和服", + "615": "护膝", + "616": "蝴蝶结", + "617": "大褂,实验室外套", + "618": "长柄勺", + "619": "灯罩", + "620": "笔记本电脑", + "621": "割草机", + "622": "镜头盖", + "623": "开信刀,裁纸刀", + "624": "图书馆", + "625": "救生艇", + "626": "点火器,打火机", + "627": "豪华轿车", + "628": "远洋班轮", + "629": "唇膏,口红", + "630": "平底便鞋", + "631": "洗剂", + "632": "扬声器", + "633": "放大镜", + "634": "锯木厂", + "635": "磁罗盘", + "636": "邮袋", + "637": "信箱", + "638": "女游泳衣", + "639": "有肩带浴衣", + "640": "窨井盖", + "641": "沙球(一种打击乐器)", + "642": "马林巴木琴", + "643": "面膜", + "644": "火柴", + "645": "花柱", + "646": "迷宫", + "647": "量杯", + "648": "药箱", + "649": "巨石,巨石结构", + "650": "麦克风", + "651": "微波炉", + "652": "军装", + "653": "奶桶", + "654": "迷你巴士", + "655": "迷你裙", + "656": "面包车", + "657": "导弹", + "658": "连指手套", + "659": "搅拌钵", + "660": "活动房屋(由汽车拖拉的)", + "661": "T型发动机小汽车", + "662": "调制解调器", + "663": "修道院", + "664": "显示器", + "665": "电瓶车", + "666": "砂浆", + "667": "学士", + "668": "清真寺", + "669": "蚊帐", + "670": "摩托车", + "671": "山地自行车", + "672": "登山帐", + "673": "鼠标,电脑鼠标", + "674": "捕鼠器", + "675": "搬家车", + "676": "口套", + "677": "钉子", + "678": "颈托", + "679": "项链", + "680": "乳头(瓶)", + "681": "笔记本,笔记本电脑", + "682": "方尖碑", + "683": "双簧管", + "684": "陶笛,卵形笛", + "685": "里程表", + "686": "滤油器", + "687": "风琴,管风琴", + "688": "示波器", + "689": "罩裙", + "690": "牛车", + "691": "氧气面罩", + "692": "包装", + "693": "船桨", + "694": "明轮,桨轮", + "695": "挂锁,扣锁", + "696": "画笔", + "697": "睡衣", + "698": "宫殿", + "699": "排箫,鸣管", + "700": "纸巾", + "701": "降落伞", + "702": "双杠", + "703": "公园长椅", + "704": "停车收费表,停车计时器", + "705": "客车,教练车", + "706": "露台,阳台", + "707": "付费电话", + "708": "基座,基脚", + "709": "铅笔盒", + "710": "卷笔刀", + "711": "香水(瓶)", + "712": "培养皿", + "713": "复印机", + "714": "拨弦片,拨子", + "715": "尖顶头盔", + "716": "栅栏,栅栏", + "717": "皮卡,皮卡车", + "718": "桥墩", + "719": "存钱罐", + "720": "药瓶", + "721": "枕头", + "722": "乒乓球", + "723": "风车", + "724": "海盗船", + "725": "水罐", + "726": "木工刨", + "727": "天文馆", + "728": "塑料袋", + "729": "板架", + "730": "犁型铲雪机", + "731": "手压皮碗泵", + "732": "宝丽来相机", + "733": "电线杆", + "734": "警车,巡逻车", + "735": "雨披", + "736": "台球桌", + "737": "充气饮料瓶", + "738": "花盆", + "739": "陶工旋盘", + "740": "电钻", + "741": "祈祷垫,地毯", + "742": "打印机", + "743": "监狱", + "744": "炮弹,导弹", + "745": "投影仪", + "746": "冰球", + "747": "沙包,吊球", + "748": "钱包", + "749": "羽管笔", + "750": "被子", + "751": "赛车", + "752": "球拍", + "753": "散热器", + "754": "收音机", + "755": "射电望远镜,无线电反射器", + "756": "雨桶", + "757": "休闲车,房车", + "758": "卷轴,卷筒", + "759": "反射式照相机", + "760": "冰箱,冰柜", + "761": "遥控器", + "762": "餐厅,饮食店,食堂", + "763": "左轮手枪", + "764": "步枪", + "765": "摇椅", + "766": "电转烤肉架", + "767": "橡皮", + "768": "橄榄球", + "769": "直尺", + "770": "跑步鞋", + "771": "保险柜", + "772": "安全别针", + "773": "盐瓶(调味用)", + "774": "凉鞋", + "775": "纱笼,围裙", + "776": "萨克斯管", + "777": "剑鞘", + "778": "秤,称重机", + "779": "校车", + "780": "帆船", + "781": "记分牌", + "782": "屏幕", + "783": "螺丝", + "784": "螺丝刀", + "785": "安全带", + "786": "缝纫机", + "787": "盾牌,盾牌", + "788": "皮鞋店,鞋店", + "789": "障子", + "790": "购物篮", + "791": "购物车", + "792": "铁锹", + "793": "浴帽", + "794": "浴帘", + "795": "滑雪板", + "796": "滑雪面罩", + "797": "睡袋", + "798": "滑尺", + "799": "滑动门", + "800": "角子老虎机", + "801": "潜水通气管", + "802": "雪橇", + "803": "扫雪机,扫雪机", + "804": "皂液器", + "805": "足球", + "806": "袜子", + "807": "碟式太阳能,太阳能集热器,太阳能炉", + "808": "宽边帽", + "809": "汤碗", + "810": "空格键", + "811": "空间加热器", + "812": "航天飞机", + "813": "铲(搅拌或涂敷用的)", + "814": "快艇", + "815": "蜘蛛网", + "816": "纺锤,纱锭", + "817": "跑车", + "818": "聚光灯", + "819": "舞台", + "820": "蒸汽机车", + "821": "钢拱桥", + "822": "钢滚筒", + "823": "听诊器", + "824": "女用披肩", + "825": "石头墙", + "826": "秒表", + "827": "火炉", + "828": "过滤器", + "829": "有轨电车,电车", + "830": "担架", + "831": "沙发床", + "832": "佛塔", + "833": "潜艇,潜水艇", + "834": "套装,衣服", + "835": "日晷", + "836": "太阳镜", + "837": "太阳镜,墨镜", + "838": "防晒霜,防晒剂", + "839": "悬索桥", + "840": "拖把", + "841": "运动衫", + "842": "游泳裤", + "843": "秋千", + "844": "开关,电器开关", + "845": "注射器", + "846": "台灯", + "847": "坦克,装甲战车,装甲战斗车辆", + "848": "磁带播放器", + "849": "茶壶", + "850": "泰迪,泰迪熊", + "851": "电视", + "852": "网球", + "853": "茅草,茅草屋顶", + "854": "幕布,剧院的帷幕", + "855": "顶针", + "856": "脱粒机", + "857": "宝座", + "858": "瓦屋顶", + "859": "烤面包机", + "860": "烟草店,烟草", + "861": "马桶", + "862": "火炬", + "863": "图腾柱", + "864": "拖车,牵引车,清障车", + "865": "玩具店", + "866": "拖拉机", + "867": "拖车,铰接式卡车", + "868": "托盘", + "869": "风衣", + "870": "三轮车", + "871": "三体船", + "872": "三脚架", + "873": "凯旋门", + "874": "无轨电车", + "875": "长号", + "876": "浴盆,浴缸", + "877": "旋转式栅门", + "878": "打字机键盘", + "879": "伞", + "880": "独轮车", + "881": "直立式钢琴", + "882": "真空吸尘器", + "883": "花瓶", + "884": "拱顶", + "885": "天鹅绒", + "886": "自动售货机", + "887": "祭服", + "888": "高架桥", + "889": "小提琴,小提琴", + "890": "排球", + "891": "松饼机", + "892": "挂钟", + "893": "钱包,皮夹", + "894": "衣柜,壁橱", + "895": "军用飞机", + "896": "洗脸盆,洗手盆", + "897": "洗衣机,自动洗衣机", + "898": "水瓶", + "899": "水壶", + "900": "水塔", + "901": "威士忌壶", + "902": "哨子", + "903": "假发", + "904": "纱窗", + "905": "百叶窗", + "906": "温莎领带", + "907": "葡萄酒瓶", + "908": "飞机翅膀,飞机", + "909": "炒菜锅", + "910": "木制的勺子", + "911": "毛织品,羊绒", + "912": "栅栏,围栏", + "913": "沉船", + "914": "双桅船", + "915": "蒙古包", + "916": "网站,互联网网站", + "917": "漫画", + "918": "纵横字谜", + "919": "路标", + "920": "交通信号灯", + "921": "防尘罩,书皮", + "922": "菜单", + "923": "盘子", + "924": "鳄梨酱", + "925": "清汤", + "926": "罐焖土豆烧肉", + "927": "蛋糕", + "928": "冰淇淋", + "929": "雪糕,冰棍,冰棒", + "930": "法式面包", + "931": "百吉饼", + "932": "椒盐脆饼", + "933": "芝士汉堡", + "934": "热狗", + "935": "土豆泥", + "936": "结球甘蓝", + "937": "西兰花", + "938": "菜花", + "939": "绿皮密生西葫芦", + "940": "西葫芦", + "941": "小青南瓜", + "942": "南瓜", + "943": "黄瓜", + "944": "朝鲜蓟", + "945": "甜椒", + "946": "刺棘蓟", + "947": "蘑菇", + "948": "绿苹果", + "949": "草莓", + "950": "橘子", + "951": "柠檬", + "952": "无花果", + "953": "菠萝", + "954": "香蕉", + "955": "菠萝蜜", + "956": "蛋奶冻苹果", + "957": "石榴", + "958": "干草", + "959": "烤面条加干酪沙司", + "960": "巧克力酱,巧克力糖浆", + "961": "面团", + "962": "瑞士肉包,肉饼", + "963": "披萨,披萨饼", + "964": "馅饼", + "965": "卷饼", + "966": "红葡萄酒", + "967": "意大利浓咖啡", + "968": "杯子", + "969": "蛋酒", + "970": "高山", + "971": "泡泡", + "972": "悬崖", + "973": "珊瑚礁", + "974": "间歇泉", + "975": "湖边,湖岸", + "976": "海角", + "977": "沙洲,沙坝", + "978": "海滨,海岸", + "979": "峡谷", + "980": "火山", + "981": "棒球,棒球运动员", + "982": "新郎", + "983": "潜水员", + "984": "油菜", + "985": "雏菊", + "986": "杓兰", + "987": "玉米", + "988": "橡子", + "989": "玫瑰果", + "990": "七叶树果实", + "991": "珊瑚菌", + "992": "木耳", + "993": "鹿花菌", + "994": "鬼笔菌", + "995": "地星(菌类)", + "996": "多叶奇果菌", + "997": "牛肝菌", + "998": "玉米穗", + "999": "卫生纸" +} diff --git a/labels/id2label_en.json b/labels/id2label_en.json new file mode 100644 index 0000000000000000000000000000000000000000..3639047285cdac0fe0de6f42b6be109bbe72f369 --- /dev/null +++ b/labels/id2label_en.json @@ -0,0 +1,1002 @@ +{ + "0": "tench, Tinca tinca", + "1": "goldfish, Carassius auratus", + "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "3": "tiger shark, Galeocerdo cuvieri", + "4": "hammerhead, hammerhead shark", + "5": "electric ray, crampfish, numbfish, torpedo", + "6": "stingray", + "7": "cock", + "8": "hen", + "9": "ostrich, Struthio camelus", + "10": "brambling, Fringilla montifringilla", + "11": "goldfinch, Carduelis carduelis", + "12": "house finch, linnet, Carpodacus mexicanus", + "13": "junco, snowbird", + "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "15": "robin, American robin, Turdus migratorius", + "16": "bulbul", + "17": "jay", + "18": "magpie", + "19": "chickadee", + "20": "water ouzel, dipper", + "21": "kite", + "22": "bald eagle, American eagle, Haliaeetus leucocephalus", + "23": "vulture", + "24": "great grey owl, great gray owl, Strix nebulosa", + "25": "European fire salamander, Salamandra salamandra", + "26": "common newt, Triturus vulgaris", + "27": "eft", + "28": "spotted salamander, Ambystoma maculatum", + "29": "axolotl, mud puppy, Ambystoma mexicanum", + "30": "bullfrog, Rana catesbeiana", + "31": "tree frog, tree-frog", + "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "33": "loggerhead, loggerhead turtle, Caretta caretta", + "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "35": "mud turtle", + "36": "terrapin", + "37": "box turtle, box tortoise", + "38": "banded gecko", + "39": "common iguana, iguana, Iguana iguana", + "40": "American chameleon, anole, Anolis carolinensis", + "41": "whiptail, whiptail lizard", + "42": "agama", + "43": "frilled lizard, Chlamydosaurus kingi", + "44": "alligator lizard", + "45": "Gila monster, Heloderma suspectum", + "46": "green lizard, Lacerta viridis", + "47": "African chameleon, Chamaeleo chamaeleon", + "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "49": "African crocodile, Nile crocodile, Crocodylus niloticus", + "50": "American alligator, Alligator mississipiensis", + "51": "triceratops", + "52": "thunder snake, worm snake, Carphophis amoenus", + "53": "ringneck snake, ring-necked snake, ring snake", + "54": "hognose snake, puff adder, sand viper", + "55": "green snake, grass snake", + "56": "king snake, kingsnake", + "57": "garter snake, grass snake", + "58": "water snake", + "59": "vine snake", + "60": "night snake, Hypsiglena torquata", + "61": "boa constrictor, Constrictor constrictor", + "62": "rock python, rock snake, Python sebae", + "63": "Indian cobra, Naja naja", + "64": "green mamba", + "65": "sea snake", + "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "68": "sidewinder, horned rattlesnake, Crotalus cerastes", + "69": "trilobite", + "70": "harvestman, daddy longlegs, Phalangium opilio", + "71": "scorpion", + "72": "black and gold garden spider, Argiope aurantia", + "73": "barn spider, Araneus cavaticus", + "74": "garden spider, Aranea diademata", + "75": "black widow, Latrodectus mactans", + "76": "tarantula", + "77": "wolf spider, hunting spider", + "78": "tick", + "79": "centipede", + "80": "black grouse", + "81": "ptarmigan", + "82": "ruffed grouse, partridge, Bonasa umbellus", + "83": "prairie chicken, prairie grouse, prairie fowl", + "84": "peacock", + "85": "quail", + "86": "partridge", + "87": "African grey, African gray, Psittacus erithacus", + "88": "macaw", + "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + "90": "lorikeet", + "91": "coucal", + "92": "bee eater", + "93": "hornbill", + "94": "hummingbird", + "95": "jacamar", + "96": "toucan", + "97": "drake", + "98": "red-breasted merganser, Mergus serrator", + "99": "goose", + "100": "black swan, Cygnus atratus", + "101": "tusker", + "102": "echidna, spiny anteater, anteater", + "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "104": "wallaby, brush kangaroo", + "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "106": "wombat", + "107": "jellyfish", + "108": "sea anemone, anemone", + "109": "brain coral", + "110": "flatworm, platyhelminth", + "111": "nematode, nematode worm, roundworm", + "112": "conch", + "113": "snail", + "114": "slug", + "115": "sea slug, nudibranch", + "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "117": "chambered nautilus, pearly nautilus, nautilus", + "118": "Dungeness crab, Cancer magister", + "119": "rock crab, Cancer irroratus", + "120": "fiddler crab", + "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "124": "crayfish, crawfish, crawdad, crawdaddy", + "125": "hermit crab", + "126": "isopod", + "127": "white stork, Ciconia ciconia", + "128": "black stork, Ciconia nigra", + "129": "spoonbill", + "130": "flamingo", + "131": "little blue heron, Egretta caerulea", + "132": "American egret, great white heron, Egretta albus", + "133": "bittern", + "134": "crane", + "135": "limpkin, Aramus pictus", + "136": "European gallinule, Porphyrio porphyrio", + "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", + "138": "bustard", + "139": "ruddy turnstone, Arenaria interpres", + "140": "red-backed sandpiper, dunlin, Erolia alpina", + "141": "redshank, Tringa totanus", + "142": "dowitcher", + "143": "oystercatcher, oyster catcher", + "144": "pelican", + "145": "king penguin, Aptenodytes patagonica", + "146": "albatross, mollymawk", + "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "149": "dugong, Dugong dugon", + "150": "sea lion", + "151": "Chihuahua", + "152": "Japanese spaniel", + "153": "Maltese dog, Maltese terrier, Maltese", + "154": "Pekinese, Pekingese, Peke", + "155": "Shih-Tzu", + "156": "Blenheim spaniel", + "157": "papillon", + "158": "toy terrier", + "159": "Rhodesian ridgeback", + "160": "Afghan hound, Afghan", + "161": "basset, basset hound", + "162": "beagle", + "163": "bloodhound, sleuthhound", + "164": "bluetick", + "165": "black-and-tan coonhound", + "166": "Walker hound, Walker foxhound", + "167": "English foxhound", + "168": "redbone", + "169": "borzoi, Russian wolfhound", + "170": "Irish wolfhound", + "171": "Italian greyhound", + "172": "whippet", + "173": "Ibizan hound, Ibizan Podenco", + "174": "Norwegian elkhound, elkhound", + "175": "otterhound, otter hound", + "176": "Saluki, gazelle hound", + "177": "Scottish deerhound, deerhound", + "178": "Weimaraner", + "179": "Staffordshire bullterrier, Staffordshire bull terrier", + "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "181": "Bedlington terrier", + "182": "Border terrier", + "183": "Kerry blue terrier", + "184": "Irish terrier", + "185": "Norfolk terrier", + "186": "Norwich terrier", + "187": "Yorkshire terrier", + "188": "wire-haired fox terrier", + "189": "Lakeland terrier", + "190": "Sealyham terrier, Sealyham", + "191": "Airedale, Airedale terrier", + "192": "cairn, cairn terrier", + "193": "Australian terrier", + "194": "Dandie Dinmont, Dandie Dinmont terrier", + "195": "Boston bull, Boston terrier", + "196": "miniature schnauzer", + "197": "giant schnauzer", + "198": "standard schnauzer", + "199": "Scotch terrier, Scottish terrier, Scottie", + "200": "Tibetan terrier, chrysanthemum dog", + "201": "silky terrier, Sydney silky", + "202": "soft-coated wheaten terrier", + "203": "West Highland white terrier", + "204": "Lhasa, Lhasa apso", + "205": "flat-coated retriever", + "206": "curly-coated retriever", + "207": "golden retriever", + "208": "Labrador retriever", + "209": "Chesapeake Bay retriever", + "210": "German short-haired pointer", + "211": "vizsla, Hungarian pointer", + "212": "English setter", + "213": "Irish setter, red setter", + "214": "Gordon setter", + "215": "Brittany spaniel", + "216": "clumber, clumber spaniel", + "217": "English springer, English springer spaniel", + "218": "Welsh springer spaniel", + "219": "cocker spaniel, English cocker spaniel, cocker", + "220": "Sussex spaniel", + "221": "Irish water spaniel", + "222": "kuvasz", + "223": "schipperke", + "224": "groenendael", + "225": "malinois", + "226": "briard", + "227": "kelpie", + "228": "komondor", + "229": "Old English sheepdog, bobtail", + "230": "Shetland sheepdog, Shetland sheep dog, Shetland", + "231": "collie", + "232": "Border collie", + "233": "Bouvier des Flandres, Bouviers des Flandres", + "234": "Rottweiler", + "235": "German shepherd, German shepherd dog, German police dog, alsatian", + "236": "Doberman, Doberman pinscher", + "237": "miniature pinscher", + "238": "Greater Swiss Mountain dog", + "239": "Bernese mountain dog", + "240": "Appenzeller", + "241": "EntleBucher", + "242": "boxer", + "243": "bull mastiff", + "244": "Tibetan mastiff", + "245": "French bulldog", + "246": "Great Dane", + "247": "Saint Bernard, St Bernard", + "248": "Eskimo dog, husky", + "249": "malamute, malemute, Alaskan malamute", + "250": "Siberian husky", + "251": "dalmatian, coach dog, carriage dog", + "252": "affenpinscher, monkey pinscher, monkey dog", + "253": "basenji", + "254": "pug, pug-dog", + "255": "Leonberg", + "256": "Newfoundland, Newfoundland dog", + "257": "Great Pyrenees", + "258": "Samoyed, Samoyede", + "259": "Pomeranian", + "260": "chow, chow chow", + "261": "keeshond", + "262": "Brabancon griffon", + "263": "Pembroke, Pembroke Welsh corgi", + "264": "Cardigan, Cardigan Welsh corgi", + "265": "toy poodle", + "266": "miniature poodle", + "267": "standard poodle", + "268": "Mexican hairless", + "269": "timber wolf, grey wolf, gray wolf, Canis lupus", + "270": "white wolf, Arctic wolf, Canis lupus tundrarum", + "271": "red wolf, maned wolf, Canis rufus, Canis niger", + "272": "coyote, prairie wolf, brush wolf, Canis latrans", + "273": "dingo, warrigal, warragal, Canis dingo", + "274": "dhole, Cuon alpinus", + "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "276": "hyena, hyaena", + "277": "red fox, Vulpes vulpes", + "278": "kit fox, Vulpes macrotis", + "279": "Arctic fox, white fox, Alopex lagopus", + "280": "grey fox, gray fox, Urocyon cinereoargenteus", + "281": "tabby, tabby cat", + "282": "tiger cat", + "283": "Persian cat", + "284": "Siamese cat, Siamese", + "285": "Egyptian cat", + "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "287": "lynx, catamount", + "288": "leopard, Panthera pardus", + "289": "snow leopard, ounce, Panthera uncia", + "290": "jaguar, panther, Panthera onca, Felis onca", + "291": "lion, king of beasts, Panthera leo", + "292": "tiger, Panthera tigris", + "293": "cheetah, chetah, Acinonyx jubatus", + "294": "brown bear, bruin, Ursus arctos", + "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", + "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "297": "sloth bear, Melursus ursinus, Ursus ursinus", + "298": "mongoose", + "299": "meerkat, mierkat", + "300": "tiger beetle", + "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "302": "ground beetle, carabid beetle", + "303": "long-horned beetle, longicorn, longicorn beetle", + "304": "leaf beetle, chrysomelid", + "305": "dung beetle", + "306": "rhinoceros beetle", + "307": "weevil", + "308": "fly", + "309": "bee", + "310": "ant, emmet, pismire", + "311": "grasshopper, hopper", + "312": "cricket", + "313": "walking stick, walkingstick, stick insect", + "314": "cockroach, roach", + "315": "mantis, mantid", + "316": "cicada, cicala", + "317": "leafhopper", + "318": "lacewing, lacewing fly", + "319": "dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "320": "damselfly", + "321": "admiral", + "322": "ringlet, ringlet butterfly", + "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "324": "cabbage butterfly", + "325": "sulphur butterfly, sulfur butterfly", + "326": "lycaenid, lycaenid butterfly", + "327": "starfish, sea star", + "328": "sea urchin", + "329": "sea cucumber, holothurian", + "330": "wood rabbit, cottontail, cottontail rabbit", + "331": "hare", + "332": "Angora, Angora rabbit", + "333": "hamster", + "334": "porcupine, hedgehog", + "335": "fox squirrel, eastern fox squirrel, Sciurus niger", + "336": "marmot", + "337": "beaver", + "338": "guinea pig, Cavia cobaya", + "339": "sorrel", + "340": "zebra", + "341": "hog, pig, grunter, squealer, Sus scrofa", + "342": "wild boar, boar, Sus scrofa", + "343": "warthog", + "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "345": "ox", + "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "347": "bison", + "348": "ram, tup", + "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "350": "ibex, Capra ibex", + "351": "hartebeest", + "352": "impala, Aepyceros melampus", + "353": "gazelle", + "354": "Arabian camel, dromedary, Camelus dromedarius", + "355": "llama", + "356": "weasel", + "357": "mink", + "358": "polecat, fitch, foulmart, foumart, Mustela putorius", + "359": "black-footed ferret, ferret, Mustela nigripes", + "360": "otter", + "361": "skunk, polecat, wood pussy", + "362": "badger", + "363": "armadillo", + "364": "three-toed sloth, ai, Bradypus tridactylus", + "365": "orangutan, orang, orangutang, Pongo pygmaeus", + "366": "gorilla, Gorilla gorilla", + "367": "chimpanzee, chimp, Pan troglodytes", + "368": "gibbon, Hylobates lar", + "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "370": "guenon, guenon monkey", + "371": "patas, hussar monkey, Erythrocebus patas", + "372": "baboon", + "373": "macaque", + "374": "langur", + "375": "colobus, colobus monkey", + "376": "proboscis monkey, Nasalis larvatus", + "377": "marmoset", + "378": "capuchin, ringtail, Cebus capucinus", + "379": "howler monkey, howler", + "380": "titi, titi monkey", + "381": "spider monkey, Ateles geoffroyi", + "382": "squirrel monkey, Saimiri sciureus", + "383": "Madagascar cat, ring-tailed lemur, Lemur catta", + "384": "indri, indris, Indri indri, Indri brevicaudatus", + "385": "Indian elephant, Elephas maximus", + "386": "African elephant, Loxodonta africana", + "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "389": "barracouta, snoek", + "390": "eel", + "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "392": "rock beauty, Holocanthus tricolor", + "393": "anemone fish", + "394": "sturgeon", + "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", + "396": "lionfish", + "397": "puffer, pufferfish, blowfish, globefish", + "398": "abacus", + "399": "abaya", + "400": "academic gown, academic robe, judge robe", + "401": "accordion, piano accordion, squeeze box", + "402": "acoustic guitar", + "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "404": "airliner", + "405": "airship, dirigible", + "406": "altar", + "407": "ambulance", + "408": "amphibian, amphibious vehicle", + "409": "analog clock", + "410": "apiary, bee house", + "411": "apron", + "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "413": "assault rifle, assault gun", + "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "415": "bakery, bakeshop, bakehouse", + "416": "balance beam, beam", + "417": "balloon", + "418": "ballpoint, ballpoint pen, ballpen, Biro", + "419": "Band Aid", + "420": "banjo", + "421": "bannister, banister, balustrade, balusters, handrail", + "422": "barbell", + "423": "barber chair", + "424": "barbershop", + "425": "barn", + "426": "barometer", + "427": "barrel, cask", + "428": "barrow, garden cart, lawn cart, wheelbarrow", + "429": "baseball", + "430": "basketball", + "431": "bassinet", + "432": "bassoon", + "433": "bathing cap, swimming cap", + "434": "bath towel", + "435": "bathtub, bathing tub, bath, tub", + "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "437": "beacon, lighthouse, beacon light, pharos", + "438": "beaker", + "439": "bearskin, busby, shako", + "440": "beer bottle", + "441": "beer glass", + "442": "bell cote, bell cot", + "443": "bib", + "444": "bicycle-built-for-two, tandem bicycle, tandem", + "445": "bikini, two-piece", + "446": "binder, ring-binder", + "447": "binoculars, field glasses, opera glasses", + "448": "birdhouse", + "449": "boathouse", + "450": "bobsled, bobsleigh, bob", + "451": "bolo tie, bolo, bola tie, bola", + "452": "bonnet, poke bonnet", + "453": "bookcase", + "454": "bookshop, bookstore, bookstall", + "455": "bottlecap", + "456": "bow", + "457": "bow tie, bow-tie, bowtie", + "458": "brass, memorial tablet, plaque", + "459": "brassiere, bra, bandeau", + "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "461": "breastplate, aegis, egis", + "462": "broom", + "463": "bucket, pail", + "464": "buckle", + "465": "bulletproof vest", + "466": "bullet train, bullet", + "467": "butcher shop, meat market", + "468": "cab, hack, taxi, taxicab", + "469": "caldron, cauldron", + "470": "candle, taper, wax light", + "471": "cannon", + "472": "canoe", + "473": "can opener, tin opener", + "474": "cardigan", + "475": "car mirror", + "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", + "477": "carpenters kit, tool kit", + "478": "carton", + "479": "car wheel", + "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "481": "cassette", + "482": "cassette player", + "483": "castle", + "484": "catamaran", + "485": "CD player", + "486": "cello, violoncello", + "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "488": "chain", + "489": "chainlink fence", + "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "491": "chain saw, chainsaw", + "492": "chest", + "493": "chiffonier, commode", + "494": "chime, bell, gong", + "495": "china cabinet, china closet", + "496": "Christmas stocking", + "497": "church, church building", + "498": "cinema, movie theater, movie theatre, movie house, picture palace", + "499": "cleaver, meat cleaver, chopper", + "500": "cliff dwelling", + "501": "cloak", + "502": "clog, geta, patten, sabot", + "503": "cocktail shaker", + "504": "coffee mug", + "505": "coffeepot", + "506": "coil, spiral, volute, whorl, helix", + "507": "combination lock", + "508": "computer keyboard, keypad", + "509": "confectionery, confectionary, candy store", + "510": "container ship, containership, container vessel", + "511": "convertible", + "512": "corkscrew, bottle screw", + "513": "cornet, horn, trumpet, trump", + "514": "cowboy boot", + "515": "cowboy hat, ten-gallon hat", + "516": "cradle", + "517": "crane", + "518": "crash helmet", + "519": "crate", + "520": "crib, cot", + "521": "Crock Pot", + "522": "croquet ball", + "523": "crutch", + "524": "cuirass", + "525": "dam, dike, dyke", + "526": "desk", + "527": "desktop computer", + "528": "dial telephone, dial phone", + "529": "diaper, nappy, napkin", + "530": "digital clock", + "531": "digital watch", + "532": "dining table, board", + "533": "dishrag, dishcloth", + "534": "dishwasher, dish washer, dishwashing machine", + "535": "disk brake, disc brake", + "536": "dock, dockage, docking facility", + "537": "dogsled, dog sled, dog sleigh", + "538": "dome", + "539": "doormat, welcome mat", + "540": "drilling platform, offshore rig", + "541": "drum, membranophone, tympan", + "542": "drumstick", + "543": "dumbbell", + "544": "Dutch oven", + "545": "electric fan, blower", + "546": "electric guitar", + "547": "electric locomotive", + "548": "entertainment center", + "549": "envelope", + "550": "espresso maker", + "551": "face powder", + "552": "feather boa, boa", + "553": "file, file cabinet, filing cabinet", + "554": "fireboat", + "555": "fire engine, fire truck", + "556": "fire screen, fireguard", + "557": "flagpole, flagstaff", + "558": "flute, transverse flute", + "559": "folding chair", + "560": "football helmet", + "561": "forklift", + "562": "fountain", + "563": "fountain pen", + "564": "four-poster", + "565": "freight car", + "566": "French horn, horn", + "567": "frying pan, frypan, skillet", + "568": "fur coat", + "569": "garbage truck, dustcart", + "570": "gasmask, respirator, gas helmet", + "571": "gas pump, gasoline pump, petrol pump, island dispenser", + "572": "goblet", + "573": "go-kart", + "574": "golf ball", + "575": "golfcart, golf cart", + "576": "gondola", + "577": "gong, tam-tam", + "578": "gown", + "579": "grand piano, grand", + "580": "greenhouse, nursery, glasshouse", + "581": "grille, radiator grille", + "582": "grocery store, grocery, food market, market", + "583": "guillotine", + "584": "hair slide", + "585": "hair spray", + "586": "half track", + "587": "hammer", + "588": "hamper", + "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "590": "hand-held computer, hand-held microcomputer", + "591": "handkerchief, hankie, hanky, hankey", + "592": "hard disc, hard disk, fixed disk", + "593": "harmonica, mouth organ, harp, mouth harp", + "594": "harp", + "595": "harvester, reaper", + "596": "hatchet", + "597": "holster", + "598": "home theater, home theatre", + "599": "honeycomb", + "600": "hook, claw", + "601": "hoopskirt, crinoline", + "602": "horizontal bar, high bar", + "603": "horse cart, horse-cart", + "604": "hourglass", + "605": "iPod", + "606": "iron, smoothing iron", + "607": "jack-o-lantern", + "608": "jean, blue jean, denim", + "609": "jeep, landrover", + "610": "jersey, T-shirt, tee shirt", + "611": "jigsaw puzzle", + "612": "jinrikisha, ricksha, rickshaw", + "613": "joystick", + "614": "kimono", + "615": "knee pad", + "616": "knot", + "617": "lab coat, laboratory coat", + "618": "ladle", + "619": "lampshade, lamp shade", + "620": "laptop, laptop computer", + "621": "lawn mower, mower", + "622": "lens cap, lens cover", + "623": "letter opener, paper knife, paperknife", + "624": "library", + "625": "lifeboat", + "626": "lighter, light, igniter, ignitor", + "627": "limousine, limo", + "628": "liner, ocean liner", + "629": "lipstick, lip rouge", + "630": "Loafer", + "631": "lotion", + "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "633": "loupe, jewelers loupe", + "634": "lumbermill, sawmill", + "635": "magnetic compass", + "636": "mailbag, postbag", + "637": "mailbox, letter box", + "638": "maillot", + "639": "maillot, tank suit", + "640": "manhole cover", + "641": "maraca", + "642": "marimba, xylophone", + "643": "mask", + "644": "matchstick", + "645": "maypole", + "646": "maze, labyrinth", + "647": "measuring cup", + "648": "medicine chest, medicine cabinet", + "649": "megalith, megalithic structure", + "650": "microphone, mike", + "651": "microwave, microwave oven", + "652": "military uniform", + "653": "milk can", + "654": "minibus", + "655": "miniskirt, mini", + "656": "minivan", + "657": "missile", + "658": "mitten", + "659": "mixing bowl", + "660": "mobile home, manufactured home", + "661": "Model T", + "662": "modem", + "663": "monastery", + "664": "monitor", + "665": "moped", + "666": "mortar", + "667": "mortarboard", + "668": "mosque", + "669": "mosquito net", + "670": "motor scooter, scooter", + "671": "mountain bike, all-terrain bike, off-roader", + "672": "mountain tent", + "673": "mouse, computer mouse", + "674": "mousetrap", + "675": "moving van", + "676": "muzzle", + "677": "nail", + "678": "neck brace", + "679": "necklace", + "680": "nipple", + "681": "notebook, notebook computer", + "682": "obelisk", + "683": "oboe, hautboy, hautbois", + "684": "ocarina, sweet potato", + "685": "odometer, hodometer, mileometer, milometer", + "686": "oil filter", + "687": "organ, pipe organ", + "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "689": "overskirt", + "690": "oxcart", + "691": "oxygen mask", + "692": "packet", + "693": "paddle, boat paddle", + "694": "paddlewheel, paddle wheel", + "695": "padlock", + "696": "paintbrush", + "697": "pajama, pyjama, pjs, jammies", + "698": "palace", + "699": "panpipe, pandean pipe, syrinx", + "700": "paper towel", + "701": "parachute, chute", + "702": "parallel bars, bars", + "703": "park bench", + "704": "parking meter", + "705": "passenger car, coach, carriage", + "706": "patio, terrace", + "707": "pay-phone, pay-station", + "708": "pedestal, plinth, footstall", + "709": "pencil box, pencil case", + "710": "pencil sharpener", + "711": "perfume, essence", + "712": "Petri dish", + "713": "photocopier", + "714": "pick, plectrum, plectron", + "715": "pickelhaube", + "716": "picket fence, paling", + "717": "pickup, pickup truck", + "718": "pier", + "719": "piggy bank, penny bank", + "720": "pill bottle", + "721": "pillow", + "722": "ping-pong ball", + "723": "pinwheel", + "724": "pirate, pirate ship", + "725": "pitcher, ewer", + "726": "plane, carpenters plane, woodworking plane", + "727": "planetarium", + "728": "plastic bag", + "729": "plate rack", + "730": "plow, plough", + "731": "plunger, plumbers helper", + "732": "Polaroid camera, Polaroid Land camera", + "733": "pole", + "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "735": "poncho", + "736": "pool table, billiard table, snooker table", + "737": "pop bottle, soda bottle", + "738": "pot, flowerpot", + "739": "potters wheel", + "740": "power drill", + "741": "prayer rug, prayer mat", + "742": "printer", + "743": "prison, prison house", + "744": "projectile, missile", + "745": "projector", + "746": "puck, hockey puck", + "747": "punching bag, punch bag, punching ball, punchball", + "748": "purse", + "749": "quill, quill pen", + "750": "quilt, comforter, comfort, puff", + "751": "racer, race car, racing car", + "752": "racket, racquet", + "753": "radiator", + "754": "radio, wireless", + "755": "radio telescope, radio reflector", + "756": "rain barrel", + "757": "recreational vehicle, RV, R.V.", + "758": "reel", + "759": "reflex camera", + "760": "refrigerator, icebox", + "761": "remote control, remote", + "762": "restaurant, eating house, eating place, eatery", + "763": "revolver, six-gun, six-shooter", + "764": "rifle", + "765": "rocking chair, rocker", + "766": "rotisserie", + "767": "rubber eraser, rubber, pencil eraser", + "768": "rugby ball", + "769": "rule, ruler", + "770": "running shoe", + "771": "safe", + "772": "safety pin", + "773": "saltshaker, salt shaker", + "774": "sandal", + "775": "sarong", + "776": "sax, saxophone", + "777": "scabbard", + "778": "scale, weighing machine", + "779": "school bus", + "780": "schooner", + "781": "scoreboard", + "782": "screen, CRT screen", + "783": "screw", + "784": "screwdriver", + "785": "seat belt, seatbelt", + "786": "sewing machine", + "787": "shield, buckler", + "788": "shoe shop, shoe-shop, shoe store", + "789": "shoji", + "790": "shopping basket", + "791": "shopping cart", + "792": "shovel", + "793": "shower cap", + "794": "shower curtain", + "795": "ski", + "796": "ski mask", + "797": "sleeping bag", + "798": "slide rule, slipstick", + "799": "sliding door", + "800": "slot, one-armed bandit", + "801": "snorkel", + "802": "snowmobile", + "803": "snowplow, snowplough", + "804": "soap dispenser", + "805": "soccer ball", + "806": "sock", + "807": "solar dish, solar collector, solar furnace", + "808": "sombrero", + "809": "soup bowl", + "810": "space bar", + "811": "space heater", + "812": "space shuttle", + "813": "spatula", + "814": "speedboat", + "815": "spider web, spiders web", + "816": "spindle", + "817": "sports car, sport car", + "818": "spotlight, spot", + "819": "stage", + "820": "steam locomotive", + "821": "steel arch bridge", + "822": "steel drum", + "823": "stethoscope", + "824": "stole", + "825": "stone wall", + "826": "stopwatch, stop watch", + "827": "stove", + "828": "strainer", + "829": "streetcar, tram, tramcar, trolley, trolley car", + "830": "stretcher", + "831": "studio couch, day bed", + "832": "stupa, tope", + "833": "submarine, pigboat, sub, U-boat", + "834": "suit, suit of clothes", + "835": "sundial", + "836": "sunglass", + "837": "sunglasses, dark glasses, shades", + "838": "sunscreen, sunblock, sun blocker", + "839": "suspension bridge", + "840": "swab, swob, mop", + "841": "sweatshirt", + "842": "swimming trunks, bathing trunks", + "843": "swing", + "844": "switch, electric switch, electrical switch", + "845": "syringe", + "846": "table lamp", + "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", + "848": "tape player", + "849": "teapot", + "850": "teddy, teddy bear", + "851": "television, television system", + "852": "tennis ball", + "853": "thatch, thatched roof", + "854": "theater curtain, theatre curtain", + "855": "thimble", + "856": "thresher, thrasher, threshing machine", + "857": "throne", + "858": "tile roof", + "859": "toaster", + "860": "tobacco shop, tobacconist shop, tobacconist", + "861": "toilet seat", + "862": "torch", + "863": "totem pole", + "864": "tow truck, tow car, wrecker", + "865": "toyshop", + "866": "tractor", + "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "868": "tray", + "869": "trench coat", + "870": "tricycle, trike, velocipede", + "871": "trimaran", + "872": "tripod", + "873": "triumphal arch", + "874": "trolleybus, trolley coach, trackless trolley", + "875": "trombone", + "876": "tub, vat", + "877": "turnstile", + "878": "typewriter keyboard", + "879": "umbrella", + "880": "unicycle, monocycle", + "881": "upright, upright piano", + "882": "vacuum, vacuum cleaner", + "883": "vase", + "884": "vault", + "885": "velvet", + "886": "vending machine", + "887": "vestment", + "888": "viaduct", + "889": "violin, fiddle", + "890": "volleyball", + "891": "waffle iron", + "892": "wall clock", + "893": "wallet, billfold, notecase, pocketbook", + "894": "wardrobe, closet, press", + "895": "warplane, military plane", + "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "897": "washer, automatic washer, washing machine", + "898": "water bottle", + "899": "water jug", + "900": "water tower", + "901": "whiskey jug", + "902": "whistle", + "903": "wig", + "904": "window screen", + "905": "window shade", + "906": "Windsor tie", + "907": "wine bottle", + "908": "wing", + "909": "wok", + "910": "wooden spoon", + "911": "wool, woolen, woollen", + "912": "worm fence, snake fence, snake-rail fence, Virginia fence", + "913": "wreck", + "914": "yawl", + "915": "yurt", + "916": "web site, website, internet site, site", + "917": "comic book", + "918": "crossword puzzle, crossword", + "919": "street sign", + "920": "traffic light, traffic signal, stoplight", + "921": "book jacket, dust cover, dust jacket, dust wrapper", + "922": "menu", + "923": "plate", + "924": "guacamole", + "925": "consomme", + "926": "hot pot, hotpot", + "927": "trifle", + "928": "ice cream, icecream", + "929": "ice lolly, lolly, lollipop, popsicle", + "930": "French loaf", + "931": "bagel, beigel", + "932": "pretzel", + "933": "cheeseburger", + "934": "hotdog, hot dog, red hot", + "935": "mashed potato", + "936": "head cabbage", + "937": "broccoli", + "938": "cauliflower", + "939": "zucchini, courgette", + "940": "spaghetti squash", + "941": "acorn squash", + "942": "butternut squash", + "943": "cucumber, cuke", + "944": "artichoke, globe artichoke", + "945": "bell pepper", + "946": "cardoon", + "947": "mushroom", + "948": "Granny Smith", + "949": "strawberry", + "950": "orange", + "951": "lemon", + "952": "fig", + "953": "pineapple, ananas", + "954": "banana", + "955": "jackfruit, jak, jack", + "956": "custard apple", + "957": "pomegranate", + "958": "hay", + "959": "carbonara", + "960": "chocolate sauce, chocolate syrup", + "961": "dough", + "962": "meat loaf, meatloaf", + "963": "pizza, pizza pie", + "964": "potpie", + "965": "burrito", + "966": "red wine", + "967": "espresso", + "968": "cup", + "969": "eggnog", + "970": "alp", + "971": "bubble", + "972": "cliff, drop, drop-off", + "973": "coral reef", + "974": "geyser", + "975": "lakeside, lakeshore", + "976": "promontory, headland, head, foreland", + "977": "sandbar, sand bar", + "978": "seashore, coast, seacoast, sea-coast", + "979": "valley, vale", + "980": "volcano", + "981": "ballplayer, baseball player", + "982": "groom, bridegroom", + "983": "scuba diver", + "984": "rapeseed", + "985": "daisy", + "986": "yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "987": "corn", + "988": "acorn", + "989": "hip, rose hip, rosehip", + "990": "buckeye, horse chestnut, conker", + "991": "coral fungus", + "992": "agaric", + "993": "gyromitra", + "994": "stinkhorn, carrion fungus", + "995": "earthstar", + "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "997": "bolete", + "998": "ear, spike, capitulum", + "999": "toilet tissue, toilet paper, bathroom tissue" +} diff --git a/labels/imagenet_labels.py b/labels/imagenet_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..398a385ceecc61ca14b66ac9f64fb1d50c7423e7 --- /dev/null +++ b/labels/imagenet_labels.py @@ -0,0 +1,61 @@ +"""ImageNet-1k class labels for JiT class-conditional generation. + +Labels are stored as Hugging Face-style ``id2label`` JSON maps (string keys ``"0"``–``"999"``). +Each value is a comma-separated list of synonyms for that class id. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Literal + +Language = Literal["en", "cn"] + +_LABELS_DIR = Path(__file__).resolve().parent + + +def load_id2label( + labels_dir: Path | str | None = None, + lang: Language = "en", +) -> dict[int, str]: + """Load ``id2label`` from ``id2label_en.json`` or ``id2label_cn.json``.""" + root = Path(labels_dir) if labels_dir is not None else _LABELS_DIR + filename = "id2label_en.json" if lang == "en" else "id2label_cn.json" + path = root / filename + if not path.exists(): + raise FileNotFoundError(f"ImageNet label file not found: {path}") + + raw = json.loads(path.read_text(encoding="utf-8")) + return {int(key): value for key, value in raw.items()} + + +def build_label2id(id2label: dict[int, str]) -> dict[str, int]: + """Build a synonym -> class id map from an ``id2label`` dict (DiT-style).""" + labels: dict[str, int] = {} + for class_id, value in id2label.items(): + for synonym in value.split(","): + synonym = synonym.strip() + if synonym: + labels[synonym] = int(class_id) + return dict(sorted(labels.items())) + + +def resolve_label_ids( + labels: str | list[str], + label2id: dict[str, int], + *, + lang: Language = "en", +) -> list[int]: + """Map one or more label strings to ImageNet class ids.""" + if isinstance(labels, str): + labels = [labels] + + missing = [label for label in labels if label not in label2id] + if missing: + preview = ", ".join(list(label2id.keys())[:8]) + raise ValueError( + f"Unknown label(s) for lang={lang!r}: {missing}. " + f"Example valid labels: {preview}, ..." + ) + return [label2id[label] for label in labels]