Instructions to use BiliSakura/JiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/JiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import argparse | |
| from collections.abc import Mapping | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Literal, Tuple | |
| import torch | |
| from diffusers import ConfigMixin, ModelMixin | |
| from diffusers.configuration_utils import register_to_config | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from .modeling_jit_backbone import JiT_models | |
| def _extract_module_state_dict( | |
| state_dict: Dict[str, torch.Tensor], prefixes: Tuple[str, ...] = ("transformer.", "net.") | |
| ) -> Dict[str, torch.Tensor]: | |
| """Extract module state by stripping the first fully-matching prefix. | |
| Prefix precedence is left-to-right; `"transformer."` is preferred over legacy `"net."`. | |
| """ | |
| for prefix in prefixes: | |
| if all(key.startswith(prefix) for key in state_dict.keys()): | |
| return {k[len(prefix):]: v for k, v in state_dict.items()} | |
| return state_dict | |
| def _build_jit_kwargs( | |
| image_size: int, | |
| num_classes: int, | |
| attn_dropout: float, | |
| proj_dropout: float, | |
| model_name: str | None = None, | |
| ) -> Dict[str, object]: | |
| # Keep model_name for backward-compatible internal call signatures. | |
| _ = model_name | |
| return { | |
| "input_size": image_size, | |
| "in_channels": 3, | |
| "num_classes": num_classes, | |
| "attn_drop": attn_dropout, | |
| "proj_drop": proj_dropout, | |
| } | |
@dataclass
class JiTCheckpointConfig:
    """Architecture settings recovered from the args stored in a JiT checkpoint.

    Declared as a dataclass so it can be built with keyword arguments; a bare
    class holding only annotations has no such constructor.
    """

    # Model identifier; used as the lookup key when the backbone is built.
    model_name: str
    # Input image size handed to the backbone.
    image_size: int
    # Number of class labels the backbone is built for.
    num_classes: int
    # Dropout value forwarded as the backbone's attention dropout.
    attn_dropout: float
    # Dropout value forwarded as the backbone's projection dropout.
    proj_dropout: float
| def _config_from_checkpoint(ckpt_args: argparse.Namespace | Mapping[str, Any]) -> JiTCheckpointConfig: | |
| if isinstance(ckpt_args, argparse.Namespace): | |
| args_dict = vars(ckpt_args) | |
| elif isinstance(ckpt_args, Mapping): | |
| args_dict = ckpt_args | |
| else: | |
| raise TypeError(f"Unsupported checkpoint args type: {type(ckpt_args)}") | |
| def _get_first_available(*keys: str, default=None): | |
| for key in keys: | |
| if key in args_dict and args_dict[key] is not None: | |
| return args_dict[key] | |
| return default | |
| model_name = _get_first_available("model", "model_name", "model_type") | |
| image_size = _get_first_available("img_size", "image_size", "sample_size") | |
| num_classes = _get_first_available("class_num", "num_classes", "num_class_embeds") | |
| if model_name is None or image_size is None or num_classes is None: | |
| raise ValueError("Checkpoint args are missing model/image_size/num_classes information.") | |
| return JiTCheckpointConfig( | |
| model_name=str(model_name), | |
| image_size=int(image_size), | |
| num_classes=int(num_classes), | |
| attn_dropout=float(_get_first_available("attn_dropout", "attention_dropout", default=0.0)), | |
| proj_dropout=float(_get_first_available("proj_dropout", "dropout", default=0.0)), | |
| ) | |
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    """Diffusers model wrapper around a JiT backbone selected from ``JiT_models``.

    The backbone lives in ``self.transformer``; ``net`` is a legacy accessor
    for the same module. Checkpoints in the original JiT layout can be read
    with ``from_jit_checkpoint`` and produced with ``to_jit_checkpoint``.
    """

    @register_to_config
    def __init__(
        self,
        model_type: str = "JiT-B/16",
        sample_size: int = 256,
        num_class_embeds: int = 1000,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        model_name: str | None = None,
        image_size: int | None = None,
        num_classes: int | None = None,
        attn_dropout: float | None = None,
        proj_dropout: float | None = None,
    ):
        """Instantiate the backbone named by ``model_type``.

        Args:
            model_type: Key into ``JiT_models`` selecting the backbone.
            sample_size: Input size forwarded to the backbone as ``input_size``.
            num_class_embeds: Class count forwarded as ``num_classes``.
            attention_dropout: Forwarded as ``attn_drop``.
            dropout: Forwarded as ``proj_drop``.
            model_name: Legacy alias; overrides ``model_type`` when not ``None``.
            image_size: Legacy alias; overrides ``sample_size`` when not ``None``.
            num_classes: Legacy alias; overrides ``num_class_embeds`` when not ``None``.
            attn_dropout: Legacy alias; overrides ``attention_dropout`` when not ``None``.
            proj_dropout: Legacy alias; overrides ``dropout`` when not ``None``.

        Raises:
            ValueError: if the resolved model type is not a key of ``JiT_models``.
        """
        super().__init__()
        # A legacy alias takes precedence over its new-style counterpart when given.
        resolved_model_type = model_type if model_name is None else model_name
        resolved_sample_size = sample_size if image_size is None else image_size
        resolved_num_class_embeds = num_class_embeds if num_classes is None else num_classes
        resolved_attention_dropout = attention_dropout if attn_dropout is None else attn_dropout
        resolved_dropout = dropout if proj_dropout is None else proj_dropout

        if resolved_model_type not in JiT_models:
            raise ValueError(f"Unknown model '{resolved_model_type}'. Available: {list(JiT_models.keys())}")

        self.transformer = JiT_models[resolved_model_type](
            **_build_jit_kwargs(
                image_size=resolved_sample_size,
                num_classes=resolved_num_class_embeds,
                attn_dropout=resolved_attention_dropout,
                proj_dropout=resolved_dropout,
                model_name=resolved_model_type,
            )
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        return_dict: bool = True,
    ):
        """Run the backbone on ``sample`` with per-sample timesteps and class labels.

        ``timestep`` is converted to a 1-D tensor on ``sample``'s device. A
        scalar, or a single-element tensor when the batch is larger than one,
        is repeated to ``sample.shape[0]`` entries.

        Returns:
            ``Transformer2DModelOutput`` holding the backbone output, or a
            one-element tuple with that output when ``return_dict`` is false.
        """
        timestep = torch.as_tensor(timestep, device=sample.device)
        if timestep.ndim == 0:
            timestep = timestep.repeat(sample.shape[0])
        else:
            timestep = timestep.reshape(-1)
            if timestep.shape[0] == 1 and sample.shape[0] > 1:
                timestep = timestep.repeat(sample.shape[0])

        denoised = self.transformer(sample, timestep, class_labels)
        if not return_dict:
            return (denoised,)
        return Transformer2DModelOutput(sample=denoised)

    @classmethod
    def from_jit_checkpoint(
        cls,
        checkpoint_path: str,
        weights: Literal["model", "ema1", "ema2"] = "ema1",
        map_location: str = "cpu",
        strict: bool = True,
    ) -> Tuple["JiTTransformer2DModel", Dict[str, object]]:
        """Build a model from a checkpoint file in the original JiT layout.

        Args:
            checkpoint_path: Path handed to ``torch.load``.
            weights: Which weight set to load. ``"model"`` reads the
                ``"model"`` entry; ``"ema1"`` / ``"ema2"`` read
                ``"model_ema1"`` / ``"model_ema2"``.
            map_location: Forwarded to ``torch.load``.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``(model, metadata)``. ``metadata`` records the checkpoint path,
            the chosen weight set, and the checkpoint's ``epoch`` and ``args``
            entries.

        Raises:
            ValueError: if the checkpoint has no ``"args"`` entry or lacks the
                requested weight key.
        """
        # NOTE(security): torch.load unpickles the file, so only load checkpoints
        # from trusted sources. The "args" entry may be an argparse.Namespace,
        # so confirm compatibility before switching to weights-only loading.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        if "args" not in checkpoint:
            raise ValueError("Checkpoint is missing 'args', cannot infer JiT architecture config.")

        config = _config_from_checkpoint(checkpoint["args"])
        model = cls(
            model_type=config.model_name,
            sample_size=config.image_size,
            num_class_embeds=config.num_classes,
            attention_dropout=config.attn_dropout,
            dropout=config.proj_dropout,
        )

        key = "model" if weights == "model" else f"model_{weights}"
        if key not in checkpoint:
            raise ValueError(f"Checkpoint key '{key}' not found. Available keys: {list(checkpoint.keys())}")

        # Stored keys may carry a wrapper prefix; strip it before loading into the backbone.
        model_state = _extract_module_state_dict(checkpoint[key])
        model.transformer.load_state_dict(model_state, strict=strict)

        metadata = {
            "checkpoint_path": checkpoint_path,
            "weights": weights,
            "epoch": checkpoint.get("epoch"),
            "source_args": checkpoint.get("args"),
        }
        return model, metadata

    def to_jit_checkpoint(
        self,
        ema_mode: Literal["none", "copy_to_both"] = "copy_to_both",
        prefix: str = "net.",
    ) -> Dict[str, object]:
        """Export the backbone weights as a dict in the original JiT checkpoint layout.

        Args:
            ema_mode: ``"copy_to_both"`` also writes ``"model_ema1"`` and
                ``"model_ema2"`` as clones of the exported weights; ``"none"``
                writes only ``"model"``.
            prefix: String prepended to every state-dict key.

        Returns:
            A dict with a ``"model"`` entry and, depending on ``ema_mode``,
            the two EMA entries.

        Raises:
            ValueError: if ``ema_mode`` is not one of the supported values.
        """
        # Reject bad modes before doing any tensor copies.
        if ema_mode not in ("none", "copy_to_both"):
            raise ValueError(f"Unsupported ema_mode='{ema_mode}'.")

        # detach().cpu() does not copy tensors that already live on the CPU, so
        # the "model" entries may share storage with the live parameters.
        base_state = {f"{prefix}{k}": v.detach().cpu() for k, v in self.transformer.state_dict().items()}
        checkpoint = {"model": base_state}
        if ema_mode == "copy_to_both":
            checkpoint["model_ema1"] = {k: v.clone() for k, v in base_state.items()}
            checkpoint["model_ema2"] = {k: v.clone() for k, v in base_state.items()}
        return checkpoint

    @property
    def net(self):
        """Legacy accessor returning the backbone stored in ``self.transformer``."""
        return self.transformer

    @net.setter
    def net(self, module):
        # NOTE(review): torch.nn.Module.__setattr__ registers Module values itself,
        # which may bypass this setter for `model.net = module` - confirm.
        self.transformer = module


# Backward-compatible alias so references to `JiTDiffusersModel` resolve to the same class.
JiTDiffusersModel = JiTTransformer2DModel