Instructions to use BiliSakura/JiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/JiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Fix generator determinism: forward generator through scheduler steps and seeded noise
Browse files- JiT-B-16/pipeline.py +59 -151
- JiT-B-32/pipeline.py +59 -151
- JiT-H-16/pipeline.py +59 -151
- JiT-H-32/pipeline.py +59 -151
- JiT-L-16/pipeline.py +59 -151
- JiT-L-32/pipeline.py +59 -151
JiT-B-16/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|
JiT-B-32/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|
JiT-H-16/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|
JiT-H-32/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|
JiT-L-16/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|
JiT-L-32/pipeline.py
CHANGED
|
@@ -1,36 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import importlib
|
| 16 |
import json
|
| 17 |
-
import sys
|
| 18 |
from pathlib import Path
|
| 19 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
| 22 |
-
|
| 23 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 24 |
-
from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
|
| 25 |
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
|
| 27 |
-
|
| 28 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 29 |
256: 1.0,
|
| 30 |
512: 2.0,
|
| 31 |
}
|
| 32 |
|
| 33 |
-
|
| 34 |
class JiTPipeline(DiffusionPipeline):
|
| 35 |
r"""
|
| 36 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 44 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 45 |
"""
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
|
| 51 |
-
"""Load a self-contained variant folder locally or from the Hub.
|
| 52 |
-
|
| 53 |
-
Examples:
|
| 54 |
-
JiTPipeline.from_pretrained(".")
|
| 55 |
-
JiTPipeline.from_pretrained("./JiT-H-32")
|
| 56 |
-
DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
|
| 57 |
-
"""
|
| 58 |
-
repo_root = Path(__file__).resolve().parent
|
| 59 |
-
|
| 60 |
-
if pretrained_model_name_or_path in (None, "", "."):
|
| 61 |
-
variant = repo_root
|
| 62 |
-
elif (
|
| 63 |
-
isinstance(pretrained_model_name_or_path, str)
|
| 64 |
-
and "/" in pretrained_model_name_or_path
|
| 65 |
-
and not Path(pretrained_model_name_or_path).exists()
|
| 66 |
-
):
|
| 67 |
-
from huggingface_hub import snapshot_download
|
| 68 |
-
|
| 69 |
-
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
|
| 70 |
-
if subfolder:
|
| 71 |
-
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
|
| 72 |
-
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
|
| 73 |
-
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
|
| 74 |
-
else:
|
| 75 |
-
variant = Path(pretrained_model_name_or_path)
|
| 76 |
-
if not variant.is_absolute():
|
| 77 |
-
candidate = (Path.cwd() / variant).resolve()
|
| 78 |
-
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
|
| 79 |
-
if subfolder:
|
| 80 |
-
variant = variant / subfolder
|
| 81 |
-
|
| 82 |
-
id2label_override = kwargs.pop("id2label", None)
|
| 83 |
-
model_kwargs = dict(kwargs)
|
| 84 |
-
inserted: List[str] = []
|
| 85 |
-
|
| 86 |
-
def _load_component(folder: str, module_name: str, class_name: str):
|
| 87 |
-
comp_dir = variant / folder
|
| 88 |
-
module_path = comp_dir / f"{module_name}.py"
|
| 89 |
-
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
|
| 90 |
-
if not module_path.exists() or not has_weights:
|
| 91 |
-
return None
|
| 92 |
-
|
| 93 |
-
comp_path = str(comp_dir)
|
| 94 |
-
if comp_path not in sys.path:
|
| 95 |
-
sys.path.insert(0, comp_path)
|
| 96 |
-
inserted.append(comp_path)
|
| 97 |
-
|
| 98 |
-
module = importlib.import_module(module_name)
|
| 99 |
-
component_cls = getattr(module, class_name)
|
| 100 |
-
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
|
| 104 |
-
try:
|
| 105 |
-
scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
|
| 106 |
-
except Exception:
|
| 107 |
-
scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 108 |
-
|
| 109 |
-
if transformer is None:
|
| 110 |
-
raise ValueError(f"No loadable transformer found under {variant}")
|
| 111 |
-
|
| 112 |
-
variant_path = str(variant)
|
| 113 |
-
model_index_path = variant / "model_index.json"
|
| 114 |
-
id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
|
| 115 |
-
|
| 116 |
-
pipe = cls(
|
| 117 |
-
transformer=transformer,
|
| 118 |
-
scheduler=scheduler,
|
| 119 |
-
id2label=id2label,
|
| 120 |
-
)
|
| 121 |
-
if variant_path and hasattr(pipe, "register_to_config"):
|
| 122 |
-
pipe.register_to_config(_name_or_path=variant_path)
|
| 123 |
-
return pipe
|
| 124 |
-
finally:
|
| 125 |
-
for comp_path in inserted:
|
| 126 |
-
if comp_path in sys.path:
|
| 127 |
-
sys.path.remove(comp_path)
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
transformer,
|
| 132 |
-
scheduler
|
| 133 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 134 |
):
|
| 135 |
super().__init__()
|
| 136 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 137 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
| 138 |
-
|
| 139 |
self._id2label = self._normalize_id2label(id2label)
|
| 140 |
self.labels = self._build_label2id(self._id2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
@staticmethod
|
| 143 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 146 |
return {int(key): value for key, value in id2label.items()}
|
| 147 |
|
| 148 |
@staticmethod
|
| 149 |
-
def _read_id2label_from_model_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if not model_index_path.exists():
|
| 151 |
return {}
|
| 152 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 167 |
|
| 168 |
@property
|
| 169 |
def id2label(self) -> Dict[int, str]:
|
| 170 |
-
|
| 171 |
return self._id2label
|
| 172 |
|
| 173 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 174 |
-
|
| 175 |
-
Map ImageNet label strings to class ids.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
label (`str` or `list[str]`):
|
| 179 |
-
One or more English label strings. Each string must match a synonym in `id2label`.
|
| 180 |
-
"""
|
| 181 |
label2id = self.labels
|
| 182 |
if not label2id:
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if isinstance(label, str):
|
| 186 |
label = [label]
|
|
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 188 |
missing = [item for item in label if item not in label2id]
|
| 189 |
if missing:
|
| 190 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 191 |
-
raise ValueError(
|
| 192 |
-
f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
|
| 193 |
-
)
|
| 194 |
return [label2id[item] for item in label]
|
| 195 |
|
| 196 |
def _normalize_class_labels(
|
|
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 225 |
output_type: Optional[str] = "pil",
|
| 226 |
return_dict: bool = True,
|
| 227 |
) -> Union[ImagePipelineOutput, Tuple]:
|
| 228 |
-
r"""
|
| 229 |
-
Generate class-conditional images.
|
| 230 |
-
|
| 231 |
-
Args:
|
| 232 |
-
class_labels (`int`, `str`, `list[int]`, or `list[str]`):
|
| 233 |
-
ImageNet class indices or human-readable English label strings.
|
| 234 |
-
guidance_scale (`float`, *optional*):
|
| 235 |
-
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 236 |
-
guidance_interval_min (`float`, defaults to `0.1`):
|
| 237 |
-
Lower bound of the CFG interval in flow time `t in [0, 1]`.
|
| 238 |
-
guidance_interval_max (`float`, defaults to `1.0`):
|
| 239 |
-
Upper bound of the CFG interval in flow time.
|
| 240 |
-
noise_scale (`float`, *optional*):
|
| 241 |
-
Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
|
| 242 |
-
t_eps (`float`, defaults to `5e-2`):
|
| 243 |
-
Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
|
| 244 |
-
generator (`torch.Generator`, *optional*):
|
| 245 |
-
RNG for reproducibility.
|
| 246 |
-
num_inference_steps (`int`, defaults to `50`):
|
| 247 |
-
Number of solver steps (at least 2).
|
| 248 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 249 |
-
`"pil"`, `"np"`, or `"pt"`.
|
| 250 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 251 |
-
Return [`ImagePipelineOutput`] if True.
|
| 252 |
-
"""
|
| 253 |
if num_inference_steps < 2:
|
| 254 |
raise ValueError("num_inference_steps must be >= 2.")
|
|
|
|
|
|
|
| 255 |
|
| 256 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 257 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 268 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 269 |
)
|
| 270 |
channels = int(self.transformer.config.in_channels)
|
| 271 |
-
null_class_val = int(
|
|
|
|
|
|
|
| 272 |
|
| 273 |
if guidance_scale is None:
|
| 274 |
guidance_scale = 1.0
|
| 275 |
if noise_scale is None:
|
| 276 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 277 |
|
| 278 |
-
latents = (
|
| 279 |
-
randn_tensor(
|
| 280 |
shape=(batch_size, channels, height, width),
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
* noise_scale
|
| 286 |
-
)
|
| 287 |
|
| 288 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 289 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 295 |
class_labels_input = class_labels_t
|
| 296 |
|
| 297 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
|
|
|
| 298 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 299 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 300 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 329 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 330 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 331 |
model_output = -(x_pred - latents) / sigma
|
| 332 |
-
latents = self.scheduler.step(model_output, t, latents).prev_sample
|
| 333 |
|
| 334 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 335 |
if output_type == "pt":
|
|
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
|
|
| 344 |
if not return_dict:
|
| 345 |
return (images,)
|
| 346 |
return ImagePipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: JiTPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import inspect
|
| 8 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import json
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
| 12 |
|
| 13 |
import torch
|
|
|
|
| 14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
|
| 15 |
from diffusers.utils.torch_utils import randn_tensor
|
| 16 |
|
|
|
|
| 17 |
RECOMMENDED_NOISE_BY_SIZE = {
|
| 18 |
256: 1.0,
|
| 19 |
512: 2.0,
|
| 20 |
}
|
| 21 |
|
|
|
|
| 22 |
class JiTPipeline(DiffusionPipeline):
|
| 23 |
r"""
|
| 24 |
Pipeline for image generation using JiT (Just image Transformer).
|
|
|
|
| 32 |
ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
|
| 33 |
"""
|
| 34 |
|
| 35 |
+
@staticmethod
|
| 36 |
+
def prepare_extra_step_kwargs(
|
| 37 |
+
scheduler,
|
| 38 |
+
generator=None,
|
| 39 |
+
eta: float | None = None,
|
| 40 |
+
):
|
| 41 |
+
kwargs = {}
|
| 42 |
+
step_params = set(inspect.signature(scheduler.step).parameters.keys())
|
| 43 |
+
if "generator" in step_params:
|
| 44 |
+
kwargs["generator"] = generator
|
| 45 |
+
if eta is not None and "eta" in step_params:
|
| 46 |
+
kwargs["eta"] = eta
|
| 47 |
+
return kwargs
|
| 48 |
|
| 49 |
+
model_cpu_offload_seq = "transformer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def __init__(
|
| 52 |
self,
|
| 53 |
transformer,
|
| 54 |
+
scheduler,
|
| 55 |
id2label: Optional[Dict[Union[int, str], str]] = None,
|
| 56 |
):
|
| 57 |
super().__init__()
|
| 58 |
scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
|
| 59 |
self.register_modules(transformer=transformer, scheduler=scheduler)
|
|
|
|
| 60 |
self._id2label = self._normalize_id2label(id2label)
|
| 61 |
self.labels = self._build_label2id(self._id2label)
|
| 62 |
+
self._labels_loaded_from_model_index = bool(self._id2label)
|
| 63 |
+
|
| 64 |
+
def _ensure_labels_loaded(self) -> None:
|
| 65 |
+
if self._labels_loaded_from_model_index:
|
| 66 |
+
return
|
| 67 |
+
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
|
| 68 |
+
if loaded:
|
| 69 |
+
self._id2label = loaded
|
| 70 |
+
self.labels = self._build_label2id(self._id2label)
|
| 71 |
+
self._labels_loaded_from_model_index = True
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
|
|
|
|
| 77 |
return {int(key): value for key, value in id2label.items()}
|
| 78 |
|
| 79 |
@staticmethod
|
| 80 |
+
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
|
| 81 |
+
if not variant_path:
|
| 82 |
+
return {}
|
| 83 |
+
variant_dir = Path(variant_path).resolve()
|
| 84 |
+
model_index_path = variant_dir / "model_index.json"
|
| 85 |
if not model_index_path.exists():
|
| 86 |
return {}
|
| 87 |
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
|
|
|
|
| 102 |
|
| 103 |
@property
|
| 104 |
def id2label(self) -> Dict[int, str]:
|
| 105 |
+
self._ensure_labels_loaded()
|
| 106 |
return self._id2label
|
| 107 |
|
| 108 |
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
|
| 109 |
+
self._ensure_labels_loaded()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
label2id = self.labels
|
| 111 |
if not label2id:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"No English labels loaded. Ensure `id2label` exists in model_index.json."
|
| 114 |
+
)
|
| 115 |
|
| 116 |
if isinstance(label, str):
|
| 117 |
label = [label]
|
|
|
|
| 119 |
missing = [item for item in label if item not in label2id]
|
| 120 |
if missing:
|
| 121 |
preview = ", ".join(list(label2id.keys())[:8])
|
| 122 |
+
raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
|
|
|
|
|
|
|
| 123 |
return [label2id[item] for item in label]
|
| 124 |
|
| 125 |
def _normalize_class_labels(
|
|
|
|
| 154 |
output_type: Optional[str] = "pil",
|
| 155 |
return_dict: bool = True,
|
| 156 |
) -> Union[ImagePipelineOutput, Tuple]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
if num_inference_steps < 2:
|
| 158 |
raise ValueError("num_inference_steps must be >= 2.")
|
| 159 |
+
if output_type not in {"pil", "np", "pt"}:
|
| 160 |
+
raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
|
| 161 |
|
| 162 |
class_label_ids = self._normalize_class_labels(class_labels)
|
| 163 |
do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
|
|
|
|
| 174 |
f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
|
| 175 |
)
|
| 176 |
channels = int(self.transformer.config.in_channels)
|
| 177 |
+
null_class_val = int(
|
| 178 |
+
getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
|
| 179 |
+
)
|
| 180 |
|
| 181 |
if guidance_scale is None:
|
| 182 |
guidance_scale = 1.0
|
| 183 |
if noise_scale is None:
|
| 184 |
noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
|
| 185 |
|
| 186 |
+
latents = randn_tensor(
|
|
|
|
| 187 |
shape=(batch_size, channels, height, width),
|
| 188 |
+
generator=generator,
|
| 189 |
+
device=self._execution_device,
|
| 190 |
+
dtype=self.transformer.dtype,
|
| 191 |
+
) * noise_scale
|
|
|
|
|
|
|
| 192 |
|
| 193 |
class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
|
| 194 |
class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
|
|
|
|
| 200 |
class_labels_input = class_labels_t
|
| 201 |
|
| 202 |
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
|
| 203 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
|
| 204 |
for t in self.progress_bar(self.scheduler.timesteps):
|
| 205 |
step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
|
| 206 |
sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
|
|
|
|
| 235 |
sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
|
| 236 |
# JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
|
| 237 |
model_output = -(x_pred - latents) / sigma
|
| 238 |
+
latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
|
| 239 |
|
| 240 |
images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
|
| 241 |
if output_type == "pt":
|
|
|
|
| 250 |
if not return_dict:
|
| 251 |
return (images,)
|
| 252 |
return ImagePipelineOutput(images=images)
|
| 253 |
+
|
| 254 |
+
JiTPipelineOutput = ImagePipelineOutput
|