Instructions to use BiliSakura/NiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/NiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
```py
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/NiT-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| """Hub custom pipeline: NiTPipeline. | |
| Load with native Hugging Face diffusers and trust_remote_code=True. | |
| """ | |
| from __future__ import annotations | |
| import inspect | |
| # Copyright 2026 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union, Any | |
| import torch | |
| from diffusers.image_processor import VaeImageProcessor | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput | |
| from diffusers.utils.torch_utils import randn_tensor | |
# Default output height/width in pixels; the pipeline's `__call__` falls back to
# this value when no explicit size is given and a VAE is loaded.
DEFAULT_NATIVE_RESOLUTION = 512

# Usage example for the pipeline, kept as a module-level string.
# NOTE(review): the original indentation inside this string was lost in
# extraction; the layout below follows the usual diffusers convention.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from pathlib import Path
        >>> from diffusers import DiffusionPipeline
        >>> import torch

        >>> model_dir = Path("./NiT-XL").resolve()
        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     str(model_dir),
        ...     local_files_only=True,
        ...     custom_pipeline=str(model_dir / "pipeline.py"),
        ...     trust_remote_code=True,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.to("cuda")

        >>> print(pipe.id2label[207])
        >>> print(pipe.get_label_ids("golden retriever"))

        >>> generator = torch.Generator(device="cuda").manual_seed(42)
        >>> image = pipe(
        ...     class_labels="golden retriever",
        ...     height=512,
        ...     width=512,
        ...     num_inference_steps=250,
        ...     guidance_scale=2.05,
        ...     guidance_interval=(0.0, 0.7),
        ...     generator=generator,
        ... ).images[0]
        ```
"""
class NiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for native-resolution class-conditional image generation with NiT.

    Parameters:
        transformer ([`NiTTransformer2DModel`]):
            Class-conditional transformer that predicts flow-matching velocity in packed latent space.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            Flow-matching Euler scheduler used by NiT.
        vae ([`AutoencoderDC`] or [`AutoencoderKL`], *optional*):
            Variational autoencoder used to decode packed transformer latents to pixels.
        id2label (`dict[int, str]`, *optional*):
            ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
    """

    model_cpu_offload_seq = "transformer->vae"
    _optional_components = ["vae"]

    @staticmethod
    def prepare_extra_step_kwargs(
        scheduler,
        generator=None,
        eta: Optional[float] = None,
    ):
        r"""
        Build the optional keyword arguments accepted by `scheduler.step`.

        Only keys that appear in the signature of `scheduler.step` are returned, so the
        result can always be splatted into the call.

        Args:
            scheduler: Object exposing a `step` method.
            generator: RNG forwarded as `generator` when `step` accepts it (may be `None`).
            eta: Forwarded as `eta` only when it is not `None` and `step` accepts it.

        Returns:
            `dict`: Keyword arguments for `scheduler.step`.
        """
        kwargs = {}
        step_params = set(inspect.signature(scheduler.step).parameters.keys())
        if "generator" in step_params:
            kwargs["generator"] = generator
        if eta is not None and "eta" in step_params:
            kwargs["eta"] = eta
        return kwargs

    def __init__(
        self,
        transformer,
        scheduler,
        vae=None,
        id2label: Optional[Dict[Union[int, str], str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self.image_processor = VaeImageProcessor()
        # Label tables: `_id2label` maps class id -> label string, `labels` maps each
        # individual synonym -> class id.
        self._id2label = self._normalize_id2label(id2label)
        self.labels = self._build_label2id(self._id2label)
        # When no mapping was passed in, it is looked up lazily from model_index.json.
        self._labels_loaded_from_model_index = bool(self._id2label)

    def _ensure_labels_loaded(self) -> None:
        r"""
        Lazily load `id2label` from `model_index.json` if no labels are available yet.

        The directory is taken from `self.config._name_or_path`. If nothing can be read,
        the label tables are left untouched and the lookup is retried on the next call.
        """
        if self._labels_loaded_from_model_index:
            return
        loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
        if loaded:
            self._id2label = loaded
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True

    @staticmethod
    def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
        r"""Return a copy of `id2label` with every key converted to `int` (empty dict for `None`/empty input)."""
        if not id2label:
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
        r"""
        Read the `id2label` entry of `<variant_path>/model_index.json`.

        Returns an empty dict when the path is empty, the file does not exist, or the
        file has no dict-valued `id2label` entry.
        """
        if not variant_path:
            return {}
        model_index_path = Path(variant_path).resolve() / "model_index.json"
        if not model_index_path.exists():
            return {}
        raw = json.loads(model_index_path.read_text(encoding="utf-8"))
        # Guard against a JSON root that is not an object (e.g. a list).
        id2label = raw.get("id2label") if isinstance(raw, dict) else None
        if not isinstance(id2label, dict):
            return {}
        return {int(key): value for key, value in id2label.items()}

    @staticmethod
    def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
        r"""
        Invert `id2label` into a synonym -> class id mapping.

        Each value is split on commas; every non-empty, stripped synonym becomes a key.
        If a synonym occurs for several ids, the last one iterated wins. The result is
        sorted by synonym.
        """
        label2id: Dict[str, int] = {}
        for class_id, value in id2label.items():
            for synonym in value.split(","):
                synonym = synonym.strip()
                if synonym:
                    label2id[synonym] = int(class_id)
        return dict(sorted(label2id.items()))

    @property
    def id2label(self) -> Dict[int, str]:
        r"""ImageNet class id to English label string (comma-separated synonyms)."""
        self._ensure_labels_loaded()
        return self._id2label

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map ImageNet label strings to class ids.

        Args:
            label (`str` or `list[str]`):
                One or more English label strings. Each string must match a synonym in `id2label`.

        Returns:
            `list[int]`: Class ids in the same order as the given labels.

        Raises:
            ValueError: If no labels are loaded, or if any label is unknown.
        """
        self._ensure_labels_loaded()
        label2id = self.labels
        if not label2id:
            raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
        if isinstance(label, str):
            label = [label]
        missing = [item for item in label if item not in label2id]
        if missing:
            preview = ", ".join(list(label2id.keys())[:8])
            raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
        return [label2id[item] for item in label]

    def _normalize_class_labels(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
    ) -> torch.LongTensor:
        r"""
        Convert any accepted `class_labels` form into a flat `torch.long` tensor on the execution device.

        Tensors are cast and flattened as-is. Scalars are treated as a single-element
        list. Lists may mix integer ids and label strings; strings are resolved through
        `get_label_ids`, integers pass through unchanged.

        Raises:
            ValueError: If no label is given, or if a label string is unknown.
        """
        device = self._execution_device
        if torch.is_tensor(class_labels):
            label_tensor = class_labels.to(device=device, dtype=torch.long).reshape(-1)
        else:
            if isinstance(class_labels, (int, str)):
                class_labels = [class_labels]
            items = list(class_labels)
            # Resolve all strings in one call so that every unknown label is reported at once;
            # skip the lookup entirely for pure-integer input (no label table required).
            names = [item for item in items if isinstance(item, str)]
            resolved = iter(self.get_label_ids(names)) if names else iter(())
            class_label_ids = [next(resolved) if isinstance(item, str) else int(item) for item in items]
            label_tensor = torch.tensor(class_label_ids, device=device, dtype=torch.long).reshape(-1)
        if label_tensor.numel() == 0:
            raise ValueError("class_labels must contain at least one label.")
        return label_tensor

    def _get_vae_spatial_downsample(self) -> int:
        r"""
        Return the pixel-to-latent downsampling factor of the VAE.

        `1` without a VAE; `32` when the VAE class is named `AutoencoderDC` or its
        `_name_or_path` contains `"dc-ae"`; otherwise `2 ** (len(block_out_channels) - 1)`.
        """
        if self.vae is None:
            return 1
        # `_name_or_path` may be stored as None; coalesce so the substring test cannot raise.
        name_or_path = getattr(self.vae.config, "_name_or_path", "") or ""
        if self.vae.__class__.__name__ == "AutoencoderDC" or "dc-ae" in name_or_path:
            return 32
        block_out_channels = getattr(self.vae.config, "block_out_channels", [0, 0, 0, 0])
        return 2 ** (len(block_out_channels) - 1)

    def check_inputs(
        self,
        height: int,
        width: int,
        num_inference_steps: int,
        output_type: str,
        interpolation: Optional[str] = None,
        ori_max_pe_len: Optional[int] = None,
    ) -> None:
        r"""
        Validate call arguments before any expensive work is done.

        Raises:
            ValueError: If `num_inference_steps < 1`, `output_type` or `interpolation` is
                unsupported, pixels are requested without a VAE, `ori_max_pe_len` is missing
                while interpolation is enabled, or the size is not divisible by the VAE
                downsample factor and (in latent space) by the transformer's patch size.
        """
        if num_inference_steps < 1:
            raise ValueError("num_inference_steps must be >= 1.")
        if output_type not in {"pil", "np", "pt", "latent"}:
            raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
        # Fail fast: decoding would raise this same error only after the full denoising loop.
        if self.vae is None and output_type != "latent":
            raise ValueError("Cannot decode latents without a VAE.")
        if interpolation is not None and interpolation not in {
            "no",
            "linear",
            "ntk-aware",
            "ntk-by-parts",
            "yarn",
            "ntk-aware-pro1",
            "ntk-aware-pro2",
            "scale1",
            "scale2",
        }:
            raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
        if interpolation not in {None, "no"} and ori_max_pe_len is None:
            raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
        spatial_downsample = self._get_vae_spatial_downsample()
        if height % spatial_downsample != 0 or width % spatial_downsample != 0:
            raise ValueError(
                f"height and width must be divisible by the VAE downsample factor {spatial_downsample}."
            )
        patch_size = int(self.transformer.config.patch_size)
        latent_height = height // spatial_downsample
        latent_width = width // spatial_downsample
        if latent_height % patch_size != 0 or latent_width % patch_size != 0:
            raise ValueError("Latent height and width must be divisible by transformer's patch_size.")

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        r"""
        Sample the initial noise in packed (per-patch) layout.

        Returns:
            `tuple`: `(packed_latents, image_sizes)` where `packed_latents` has shape
            `(batch_size * token_height * token_width, in_channels, patch_size, patch_size)`
            and `image_sizes` is a `(batch_size, 2)` long tensor holding
            `[token_height, token_width]` for every sample.
        """
        spatial_downsample = self._get_vae_spatial_downsample()
        latent_height = height // spatial_downsample
        latent_width = width // spatial_downsample
        patch_size = int(self.transformer.config.patch_size)
        token_height = latent_height // patch_size
        token_width = latent_width // patch_size
        image_sizes = torch.tensor([[token_height, token_width]] * batch_size, device=device, dtype=torch.long)
        packed_shape = (
            batch_size * token_height * token_width,
            self.transformer.config.in_channels,
            patch_size,
            patch_size,
        )
        packed_latents = randn_tensor(
            packed_shape,
            generator=generator,
            device=device,
            dtype=dtype,
        )
        return packed_latents, image_sizes

    def _maybe_configure_rope_extrapolation(
        self,
        height: int,
        width: int,
        interpolation: Optional[str],
        ori_max_pe_len: Optional[int],
        decouple: bool,
    ) -> None:
        r"""
        Configure RoPE extrapolation on the transformer for the requested size.

        Does nothing when `interpolation` is `None` or `"no"`. Otherwise the token grid
        size (pixels // VAE downsample // patch size) is passed to
        `transformer.configure_rope_extrapolation`, which mutates the transformer.
        """
        if interpolation in {None, "no"}:
            return
        spatial_downsample = self._get_vae_spatial_downsample()
        patch_size = int(self.transformer.config.patch_size)
        latent_h = height // spatial_downsample // patch_size
        latent_w = width // spatial_downsample // patch_size
        self.transformer.configure_rope_extrapolation(
            custom_freqs=interpolation,
            max_pe_len_h=latent_h,
            max_pe_len_w=latent_w,
            ori_max_pe_len=int(ori_max_pe_len),
            decouple=decouple,
        )

    def _apply_classifier_free_guidance(
        self,
        model_output: torch.Tensor,
        guidance_scale: float,
        guidance_active: bool,
    ) -> torch.Tensor:
        r"""
        Combine conditional and unconditional predictions.

        When guidance is off (`guidance_scale <= 1.0` or `guidance_active` is false) the
        output is returned unchanged. Otherwise `model_output` is split in two halves
        along dim 0 (conditional first, unconditional second) and
        `uncond + guidance_scale * (cond - uncond)` is returned.
        """
        if guidance_scale <= 1.0 or not guidance_active:
            return model_output
        model_output_cond, model_output_uncond = model_output.chunk(2)
        return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)

    def _get_vae_dtype(self, latents: torch.Tensor) -> torch.dtype:
        r"""Return the VAE's dtype, falling back to its first parameter and finally to `latents.dtype`."""
        vae_dtype = getattr(self.vae, "dtype", None)
        if vae_dtype is not None:
            return vae_dtype
        vae_params = next(self.vae.parameters(), None)
        return vae_params.dtype if vae_params is not None else latents.dtype

    def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
        r"""
        Rescale latents and decode them to images.

        Args:
            latents (`torch.Tensor`): Unpacked latents produced by the denoising loop.
            output_type (`str`): `"pil"`, `"np"`, `"pt"`, or `"latent"`.

        Returns:
            With `output_type="latent"`: the raw latents when there is no VAE, otherwise the
            latents cast to the VAE dtype and rescaled as `latents / scaling_factor + shift_factor`.
            For any other output type: the decoded image post-processed by the image processor.

        Raises:
            ValueError: If there is no VAE and `output_type` is not `"latent"`.
        """
        if self.vae is None:
            if output_type == "latent":
                return latents
            raise ValueError("Cannot decode latents without a VAE.")
        vae_dtype = self._get_vae_dtype(latents)
        latents = latents.to(dtype=vae_dtype)
        # Configs may store these keys with an explicit None (AutoencoderKL defaults
        # `shift_factor` to None); coalesce so the arithmetic below cannot fail.
        scaling_factor = getattr(self.vae.config, "scaling_factor", None)
        if scaling_factor is None:
            scaling_factor = 1.0
        shift_factor = getattr(self.vae.config, "shift_factor", None)
        if shift_factor is None:
            shift_factor = 0.0
        latents = (latents / scaling_factor) + shift_factor
        if output_type == "latent":
            return latents
        image = self.vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, str, List[Union[int, str]], torch.LongTensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        guidance_interval: Tuple[float, float] = (0.0, 1.0),
        interpolation: Optional[str] = None,
        ori_max_pe_len: Optional[int] = None,
        decouple: bool = False,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Generate class-conditional images at native resolution.

        Args:
            class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
                ImageNet class indices or human-readable English label strings. One image is
                generated per label.
            height (`int`, *optional*):
                Output image height in pixels. Defaults to `512` when a VAE is present, `256` otherwise.
            width (`int`, *optional*):
                Output image width in pixels. Defaults to `512` when a VAE is present, `256` otherwise.
            num_inference_steps (`int`, defaults to `50`):
                Number of denoising steps.
            guidance_scale (`float`, defaults to `1.0`):
                Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
            guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
                Flow-time interval where CFG is applied.
            interpolation (`str`, *optional*):
                VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
                `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
                `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
                the transformer's configured RoPE.
            ori_max_pe_len (`int`, *optional*):
                Original maximum latent side length seen during training. Required when
                `interpolation` is enabled.
            decouple (`bool`, defaults to `False`):
                Whether to decouple height and width when computing extrapolated RoPE frequencies.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.
            output_type (`str`, defaults to `"pil"`):
                `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            return_dict (`bool`, defaults to `True`):
                Return [`ImagePipelineOutput`] if True.

        Returns:
            [`ImagePipelineOutput`] or `tuple`: The generated images, wrapped in an
            `ImagePipelineOutput` when `return_dict` is true, otherwise a 1-tuple.

        Raises:
            ValueError: If the inputs are invalid (see `check_inputs`), a label is unknown,
                or the scheduler is configured with `stochastic_sampling=True`.
        """
        default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
        height = int(height or default_size)
        width = int(width or default_size)

        # Run every validation before touching model state or sampling noise.
        self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
        if getattr(self.scheduler.config, "stochastic_sampling", False):
            raise ValueError(
                "NiT expects deterministic FlowMatchEulerDiscreteScheduler stepping "
                "(scheduler.config.stochastic_sampling=False). The scheduler's stochastic_sampling "
                "path uses a different update rule than the official NiT Euler-Maruyama SDE and "
                "produces salt-and-pepper noise."
            )
        self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)

        device = self._execution_device
        model_dtype = next(self.transformer.parameters()).dtype
        class_labels_tensor = self._normalize_class_labels(class_labels)
        batch_size = class_labels_tensor.numel()
        packed_latents, image_sizes = self.prepare_latents(
            batch_size, height, width, model_dtype, device, generator
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        num_train_timesteps = self.scheduler.config.num_train_timesteps

        # The unconditional branch uses label index `num_classes`
        # (presumably the transformer's reserved null class -- confirm against the model).
        null_labels = torch.full_like(class_labels_tensor, self.transformer.config.num_classes)
        guidance_low, guidance_high = guidance_interval
        # Only pass `generator` to `step` if the scheduler accepts it; computed once, not per step.
        extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator)

        for t in self.progress_bar(self.scheduler.timesteps):
            # The transformer is conditioned on the timestep normalised by num_train_timesteps.
            flow_time = float(t) / num_train_timesteps
            guidance_active = guidance_low <= flow_time <= guidance_high
            if guidance_scale > 1.0 and guidance_active:
                # Conditional half first, unconditional half second (matches the CFG split).
                model_input = torch.cat([packed_latents, packed_latents], dim=0)
                labels = torch.cat([class_labels_tensor, null_labels], dim=0)
                model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
            else:
                model_input = packed_latents
                labels = class_labels_tensor
                model_image_sizes = image_sizes
            timestep_batch = torch.full((labels.numel(),), flow_time, device=device, dtype=model_dtype)
            model_output = self.transformer(
                model_input.to(dtype=model_dtype),
                timestep_batch,
                labels,
                image_sizes=model_image_sizes,
                return_dict=True,
            ).sample
            model_output = self._apply_classifier_free_guidance(model_output, guidance_scale, guidance_active)
            packed_latents = self.scheduler.step(
                model_output,
                t,
                packed_latents,
                **extra_step_kwargs,
            ).prev_sample

        latents = self.transformer._unpack_latents(packed_latents, image_sizes)
        image = self.decode_latents(latents, output_type=output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)


# Alias exposing the generic image output class under a pipeline-specific name.
NiTPipelineOutput = ImagePipelineOutput