# NiT-diffusers / NiT-B / pipeline.py
# Uploaded by BiliSakura
# Upload folder using huggingface_hub
# Commit 4f51e55 (verified)
"""Hub custom pipeline: NiTPipeline.
Load with native Hugging Face diffusers and trust_remote_code=True.
"""
from __future__ import annotations
import inspect
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
# Fallback output side length (pixels) used by ``NiTPipeline.__call__`` when a VAE is
# attached and ``height`` / ``width`` are not supplied (without a VAE it falls back to 256).
DEFAULT_NATIVE_RESOLUTION = 512
# Usage example for ``NiTPipeline``: local loading with ``custom_pipeline``, label lookup
# via ``id2label`` / ``get_label_ids``, and interval-limited classifier-free guidance.
# NOTE(review): not referenced anywhere else in the visible file (no
# ``replace_example_docstring`` decorator is applied) — kept as documentation only.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from pathlib import Path
>>> from diffusers import DiffusionPipeline
>>> import torch
>>> model_dir = Path("./NiT-XL").resolve()
>>> pipe = DiffusionPipeline.from_pretrained(
... str(model_dir),
... local_files_only=True,
... custom_pipeline=str(model_dir / "pipeline.py"),
... trust_remote_code=True,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.to("cuda")
>>> print(pipe.id2label[207])
>>> print(pipe.get_label_ids("golden retriever"))
>>> generator = torch.Generator(device="cuda").manual_seed(42)
>>> image = pipe(
... class_labels="golden retriever",
... height=512,
... width=512,
... num_inference_steps=250,
... guidance_scale=2.05,
... guidance_interval=(0.0, 0.7),
... generator=generator,
... ).images[0]
```
"""
class NiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for native-resolution class-conditional image generation with NiT.

    Parameters:
        transformer ([`NiTTransformer2DModel`]):
            Class-conditional transformer that predicts flow-matching velocity in packed latent space.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Flow-matching Euler scheduler used by NiT.
        vae ([`AutoencoderDC`] or [`AutoencoderKL`], *optional*):
            Variational autoencoder used to decode packed transformer latents to pixels.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
            When omitted, the mapping is read lazily from ``model_index.json`` in the directory
            recorded as the pipeline's ``_name_or_path`` the first time labels are needed.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = ["vae"]

    @staticmethod
    def prepare_extra_step_kwargs(
        scheduler,
        generator=None,
        eta: float | None = None,
    ):
        """Return the optional ``scheduler.step`` kwargs that this scheduler actually accepts.

        ``generator`` is included whenever ``scheduler.step`` has such a parameter; ``eta`` only
        when it is not ``None`` and the parameter exists.
        """
        kwargs = {}
        step_params = set(inspect.signature(scheduler.step).parameters.keys())
        if "generator" in step_params:
            kwargs["generator"] = generator
        if eta is not None and "eta" in step_params:
            kwargs["eta"] = eta
        return kwargs

    def __init__(
        self,
        transformer,
        scheduler,
        vae=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self.image_processor = VaeImageProcessor()
        self._id2label = self._normalize_id2label(id2label)
        # Reverse lookup: every comma-separated synonym -> class id.
        self.labels = self._build_label2id(self._id2label)
        # A mapping supplied at construction time disables the lazy model_index.json lookup.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        """Lazily populate ``id2label`` / ``labels`` from ``model_index.json`` if still empty.

        When nothing can be loaded the flag stays unset, so the lookup is retried on the next call.
        """
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        """Coerce mapping keys to ``int`` (JSON configs store them as strings); ``None`` -> ``{}``."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        """Read the ``id2label`` entry of ``<variant_path>/model_index.json``.

        Returns ``{}`` when the path is empty, the file does not exist, or the entry is not a dict.
        """
        if not variant_path:
            return {}
        variant_dir = Path(variant_path).resolve()
        model_index_path = variant_dir / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        id2label = raw.get("id2label")
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        """Invert ``id2label``, giving each stripped, non-empty synonym its own key (sorted by name).

        If a synonym appears under several ids, the last id encountered wins.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: One class id per input label, in input order.

        Raises:
            ValueError: If no labels are loaded, or if any label is unknown (all unknown labels
                are reported together).
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
        if isinstance(label, str):
            label = [label]
        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
    ) -> torch.LongTensor:
        """Convert user-facing class labels into a flat ``long`` tensor on the execution device.

        Accepts a tensor, a single int / str, or any iterable mixing ints and label strings.
        Strings are resolved through :meth:`get_label_ids` in a single call so that every unknown
        label is reported at once.

        Raises:
            ValueError: If no class label is given, or a label string is unknown.
        """
        device = self._execution_device
        if torch.is_tensor(class_labels):
            label_tensor = class_labels.to(device=device, dtype=torch.long).reshape(-1)
        else:
            items = [class_labels] if isinstance(class_labels, (int, str)) else list(class_labels)
            string_items = [item for item in items if isinstance(item, str)]
            # Resolve every string in one pass, then splice the ids back in input order.
            resolved = iter(self.get_label_ids(string_items)) if string_items else iter(())
            class_label_ids = [next(resolved) if isinstance(item, str) else int(item) for item in items]
            label_tensor = torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)
        if label_tensor.numel() == 0:
            raise ValueError("class_labels must contain at least one class label.")
        return label_tensor

    def _get_vae_spatial_downsample(self) -> int:
        """Pixel-to-latent spatial downsampling factor of the attached VAE (1 without a VAE).

        DC-AE models (detected by class name or a ``"dc-ae"`` substring in ``_name_or_path``)
        are treated as 32x; otherwise the factor is ``2 ** (len(block_out_channels) - 1)``.
        """
        if self.vae is None:
            return 1
        # `_name_or_path` may be present but None; guard before the substring test.
        name_or_path = str(getattr(self.vae.config, "_name_or_path", None) or "")
        if self.vae.__class__.__name__ == "AutoencoderDC" or "dc-ae" in name_or_path:
            return 32
        block_out_channels = getattr(self.vae.config, "block_out_channels", [0, 0, 0, 0])
        return 2 ** (len(block_out_channels) - 1)

    def check_inputs(
        self,
        height: int,
        width: int,
        num_inference_steps: int,
        output_type: str,
        interpolation: Optional[str] = None,
        ori_max_pe_len: Optional[int] = None,
    ) -> None:
        """Validate call arguments; raises ``ValueError`` on the first problem found.

        Checks the step count, output type, RoPE interpolation mode, and that the requested
        size is divisible by the VAE downsample factor and then by the transformer patch size.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
        if interpolation is not None and interpolation not in {
            "no",
            "linear",
            "ntk-aware",
            "ntk-by-parts",
            "yarn",
            "ntk-aware-pro1",
            "ntk-aware-pro2",
            "scale1",
            "scale2",
        }:
            raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
        if interpolation not in {None, "no"} and ori_max_pe_len is None:
            raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
        spatial_downsample = self._get_vae_spatial_downsample()
        if height % spatial_downsample != 0 or width % spatial_downsample != 0:
            raise ValueError(
                f"height and width must be divisible by the VAE downsample factor {spatial_downsample}."
            )
        patch_size = int(self.transformer.config.patch_size)
        latent_height = height // spatial_downsample
        latent_width = width // spatial_downsample
        if latent_height % patch_size != 0 or latent_width % patch_size != 0:
            raise ValueError("Latent height and width must be divisible by transformer's patch_size.")

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Sample initial Gaussian noise directly in the transformer's packed-token layout.

        Returns:
            A tuple ``(packed_latents, image_sizes)`` where ``packed_latents`` has shape
            ``(batch_size * token_h * token_w, in_channels, patch_size, patch_size)`` with the
            tokens of each image stored contiguously, and ``image_sizes`` is a ``(batch_size, 2)``
            long tensor of ``[token_h, token_w]`` per image.

        Raises:
            ValueError: If a list of generators is given whose length differs from ``batch_size``.
        """
        spatial_downsample = self._get_vae_spatial_downsample()
        patch_size = int(self.transformer.config.patch_size)
        token_height = height // spatial_downsample // patch_size
        token_width = width // spatial_downsample // patch_size
        tokens_per_image = token_height * token_width
        image_sizes = torch.tensor([[token_height, token_width]] * batch_size, device=device, dtype=torch.long)
        token_shape = (self.transformer.config.in_channels, patch_size, patch_size)

        if isinstance(generator, list) and len(generator) > 1:
            # `randn_tensor` would index a generator list by the packed first dimension
            # (batch * tokens); draw one contiguous chunk per image with its own generator instead.
            if len(generator) != batch_size:
                raise ValueError(
                    f"Received a list of {len(generator)} generators for a batch of {batch_size} "
                    "images; pass a single generator or exactly one generator per image."
                )
            packed_latents = torch.cat(
                [
                    randn_tensor((tokens_per_image, *token_shape), generator=gen, device=device, dtype=dtype)
                    for gen in generator
                ],
                dim=0,
            )
        else:
            packed_latents = randn_tensor(
                (batch_size * tokens_per_image, *token_shape),
                generator=generator,
                device=device,
                dtype=dtype,
            )
        return packed_latents, image_sizes

    def _maybe_configure_rope_extrapolation(
        self,
        height: int,
        width: int,
        interpolation: Optional[str],
        ori_max_pe_len: Optional[int],
        decouple: bool,
    ) -> None:
        """Reconfigure the transformer's RoPE for the target token grid (no-op for ``None``/``"no"``).

        Note: this mutates transformer state, which persists into later calls.
        """
        if interpolation in {None, "no"}:
            return
        spatial_downsample = self._get_vae_spatial_downsample()
        patch_size = int(self.transformer.config.patch_size)
        latent_h = height // spatial_downsample // patch_size
        latent_w = width // spatial_downsample // patch_size
        self.transformer.configure_rope_extrapolation(
            custom_freqs=interpolation,
            max_pe_len_h=latent_h,
            max_pe_len_w=latent_w,
            ori_max_pe_len=int(ori_max_pe_len),
            decouple=decouple,
        )

    def _apply_classifier_free_guidance(
        self,
        model_output: torch.Tensor,
        guidance_scale: float,
        guidance_active: bool,
    ) -> torch.Tensor:
        """Combine the stacked ``[cond, uncond]`` outputs as ``uncond + s * (cond - uncond)``.

        When CFG is off (scale <= 1 or outside the guidance interval), ``model_output`` holds
        only the conditional half and is returned unchanged.
        """
        if guidance_scale <= 1.0 or not guidance_active:
            return model_output
        model_output_cond, model_output_uncond = model_output.chunk(2)
        return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)

    def _get_vae_dtype(self, latents: torch.Tensor) -> torch.dtype:
        """VAE dtype from its ``dtype`` attribute, else its first parameter, else ``latents.dtype``."""
        vae_dtype = getattr(self.vae, "dtype", None)
        if vae_dtype is not None:
            return vae_dtype
        vae_params = next(self.vae.parameters(), None)
        return vae_params.dtype if vae_params is not None else latents.dtype

    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
        """Undo the VAE latent normalization and, unless ``output_type == "latent"``, decode to images.

        Raises:
            ValueError: If no VAE is attached and a pixel ``output_type`` is requested.
        """
        if self.vae is None:
            if output_type == "latent":
                return latents
            raise ValueError("Cannot decode latents without a VAE.")
        vae_dtype = self._get_vae_dtype(latents)
        latents = latents.to(dtype=vae_dtype)
        # Configs may declare these keys with a None value (e.g. AutoencoderKL's shift_factor),
        # so a getattr default alone is not enough.
        scaling_factor = getattr(self.vae.config, "scaling_factor", None)
        if scaling_factor is None:
            scaling_factor = 1.0
        shift_factor = getattr(self.vae.config, "shift_factor", None)
        if shift_factor is None:
            shift_factor = 0.0
        latents = (latents / scaling_factor) + shift_factor
        if output_type == "latent":
            return latents
        image = self.vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)

    @torch.inference_mode()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        guidance_interval: Tuple[float, float] = (0.0, 1.0),
        interpolation: Optional[str] = None,
        ori_max_pe_len: Optional[int] = None,
        decouple: bool = False,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images at native resolution.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings (may be mixed in a
                list). One image is generated per label.
            height (`int`, *optional*):
                Output image height in pixels. Defaults to `512` when a VAE is present, else `256`.
            width (`int`, *optional*):
                Output image width in pixels. Defaults to `512` when a VAE is present, else `256`.
            num_inference_steps (`int`, defaults to `50`):
                Number of denoising steps.
            guidance_scale (`float`, defaults to `1.0`):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
                Inclusive flow-time interval (`timestep / num_train_timesteps`) where CFG is applied.
            interpolation (`str`, *optional*):
                VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
                `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
                `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
                the transformer's configured RoPE.
            ori_max_pe_len (`int`, *optional*):
                Original maximum latent side length seen during training. Required when
                `interpolation` is enabled.
            decouple (`bool`, defaults to `False`):
                Whether to decouple height and width when computing extrapolated RoPE frequencies.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                RNG for reproducibility. A list must contain exactly one generator per image.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Returns:
            [`ImagePipelineOutput`] if `return_dict` is True, otherwise a one-element tuple with the images.
        """
        default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
        height = int(height or default_size)
        width = int(width or default_size)
        self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
        # Reject an unsupported scheduler mode before touching transformer state or allocating noise.
        if getattr(self.scheduler.config, "stochastic_sampling", False):
            raise ValueError(
                "NiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
                "(scheduler.config.stochastic_sampling=False). The scheduler's stochastic_sampling "
                "path uses a different update rule than the official NiT Euler-Maruyama SDE and "
                "produces salt-and-pepper noise."
            )
        self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
        device = self._execution_device
        model_dtype = next(self.transformer.parameters()).dtype
        class_labels_tensor = self._normalize_class_labels(class_labels)
        batch_size = class_labels_tensor.numel()
        packed_latents, image_sizes = self.prepare_latents(
            batch_size, height, width, model_dtype, device, generator
        )
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        num_train_timesteps = self.scheduler.config.num_train_timesteps
        # The class index `num_classes` is used as the unconditional (null) label for CFG.
        null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
        guidance_low, guidance_high = guidance_interval
        for t in self.progress_bar(self.scheduler.timesteps):
            flow_time = float(t) / num_train_timesteps
            guidance_active = guidance_low <= flow_time <= guidance_high
            if guidance_scale > 1.0 and guidance_active:
                # Stack conditional and unconditional passes into one forward call.
                model_input = torch.cat([packed_latents, packed_latents], dim=0)
                labels = torch.cat([class_labels_tensor, null_labels], dim=0)
                model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
            else:
                model_input = packed_latents
                labels = class_labels_tensor
                model_image_sizes = image_sizes
            timestep_batch = torch.full((labels.numel(),), flow_time, device=device, dtype=model_dtype)
            model_output = self.transformer(
                model_input.to(dtype=model_dtype),
                timestep_batch,
                labels,
                image_sizes=model_image_sizes,
                return_dict=True,
            ).sample
            model_output = self._apply_classifier_free_guidance(model_output, guidance_scale, guidance_active)
            packed_latents = self.scheduler.step(
                model_output,
                t,
                packed_latents,
                generator=generator,
            ).prev_sample
        latents = self.transformer._unpack_latents(packed_latents, image_sizes)
        image = self.decode_latents(latents, output_type=output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
# Alias: NiT's output type is the stock diffusers ``ImagePipelineOutput`` (field ``images``),
# exposed under a NiT-specific name.
NiTPipelineOutput = ImagePipelineOutput