# Author: BiliSakura
# Commit 7e23f03 (verified): Fix iMF/pMF inference by filtering unsupported pipeline kwargs
# File size: 13.4 kB
"""Notebook-style diffusers loader for BiliSakura Hub models."""
from __future__ import annotations
import gc
import inspect
import os
from pathlib import Path
from typing import Any, get_args, get_origin
import torch
from diffusers import DiffusionPipeline
import diffusers.pipelines.pipeline_utils as pipeline_utils
from huggingface_hub import snapshot_download
from model_catalog import (
ModelProfile,
get_profile,
scheduler_choices_for_profile,
uses_native_scheduler,
)
def _patch_diffusers_custom_pipeline_type_check() -> None:
    """Work around diffusers 0.36 KeyError when custom pipelines omit parsed annotations.

    Replaces ``DiffusionPipeline.from_pretrained`` with a wrapper. For the
    duration of each call, the wrapper swaps ``DiffusionPipeline._get_signature_types``
    for a version that returns an entry for *every* ``__init__`` parameter
    except ``self``. Parameters whose annotation is neither a class nor a
    generic alias map to ``(inspect.Signature.empty,)``.

    The patch is applied at most once. A flag stored on
    ``diffusers.pipelines.pipeline_utils`` makes later calls no-ops.
    """
    if getattr(pipeline_utils, "_bilisakura_type_check_patch", False):
        return

    @classmethod
    def patched_get_signature_types(cls):
        """Map each ``cls.__init__`` parameter (except ``self``) to a tuple of types."""
        signature_types = {}
        for name, param in inspect.signature(cls.__init__).parameters.items():
            if name == "self":
                continue
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                signature_types[name] = (inspect.Signature.empty,)
                continue
            origin = get_origin(annotation)
            if inspect.isclass(annotation):
                signature_types[name] = (annotation,)
            elif origin is not None:
                # Generic alias (e.g. Optional[X]): use its type arguments when it has any.
                args = get_args(annotation)
                signature_types[name] = args if args else (annotation,)
            else:
                # Anything else (e.g. an unevaluated string annotation) still gets an entry.
                signature_types[name] = (inspect.Signature.empty,)
        return signature_types

    original_from_pretrained = DiffusionPipeline.from_pretrained.__func__

    @classmethod
    def from_pretrained_patched(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Save the raw class attribute (the classmethod descriptor). Plain
        # attribute access would return a method already bound to
        # DiffusionPipeline. Writing that back in ``finally`` would make every
        # subclass call the original with cls=DiffusionPipeline afterwards.
        original_get_signature_types = inspect.getattr_static(
            DiffusionPipeline, "_get_signature_types"
        )
        DiffusionPipeline._get_signature_types = patched_get_signature_types
        try:
            return original_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)
        finally:
            DiffusionPipeline._get_signature_types = original_get_signature_types

    DiffusionPipeline.from_pretrained = from_pretrained_patched
    pipeline_utils._bilisakura_type_check_patch = True
# Install the from_pretrained type-check workaround at import time (the helper is
# idempotent), so it is in place before DiffusionPipeline.from_pretrained is used.
_patch_diffusers_custom_pipeline_type_check()
# Optional directory of checkpoints laid out as <root>/<collection>/<variant>.
# Only honoured when LOCAL_MODELS_ROOT is set to a non-empty value: Path("") is
# the current directory, which would otherwise always count as an existing root.
_LOCAL_MODELS_ROOT_ENV = os.environ.get("LOCAL_MODELS_ROOT", "").strip()
LOCAL_MODELS_ROOT = Path(_LOCAL_MODELS_ROOT_ENV).expanduser()
USE_LOCAL_MODELS = bool(_LOCAL_MODELS_ROOT_ENV) and LOCAL_MODELS_ROOT.is_dir()
# Hub organisation name, overridable via HF_MODEL_ORG.
HF_ORG = os.environ.get("HF_MODEL_ORG", "BiliSakura")
class PipelineManager:
    """Hold at most one loaded ``DiffusionPipeline`` plus the profile it came from.

    ``load`` does not move the pipeline to a device. ``move_to_cuda`` moves the
    current pipeline to CUDA.
    """

    def __init__(self) -> None:
        self._pipe: DiffusionPipeline | None = None
        self._loaded_label: str | None = None
        self._loaded_profile: ModelProfile | None = None
        # Set by move_to_cuda(); reset whenever a pipeline is loaded or unloaded.
        self._on_cuda: bool = False

    @property
    def loaded_label(self) -> str | None:
        """Label of the loaded profile, or None when nothing is loaded."""
        return self._loaded_label

    @property
    def loaded_profile(self) -> ModelProfile | None:
        """Profile of the loaded pipeline, or None when nothing is loaded."""
        return self._loaded_profile

    @property
    def pipe(self) -> DiffusionPipeline | None:
        """The loaded pipeline, or None."""
        return self._pipe

    def _local_model_dir(self, profile: ModelProfile) -> Path | None:
        """Return LOCAL_MODELS_ROOT/<collection>/<variant> when local models are enabled and it exists."""
        if not USE_LOCAL_MODELS:
            return None
        local_path = LOCAL_MODELS_ROOT / profile.collection / profile.variant
        return local_path if local_path.is_dir() else None

    def _resolve_model_source(self, profile: ModelProfile) -> tuple[str, str]:
        """Return local checkpoint path and a human-readable source label.

        A local checkpoint directory is preferred. Otherwise only the
        ``<variant>/`` subfolder of ``profile.hub_repo`` is fetched with
        ``snapshot_download``, and the path of that subfolder is returned.
        """
        local_path = self._local_model_dir(profile)
        if local_path is not None:
            return str(local_path.resolve()), str(local_path)
        repo_id = profile.hub_repo
        cached_repo = snapshot_download(
            repo_id,
            repo_type="model",
            allow_patterns=[f"{profile.variant}/**"],
        )
        model_dir = Path(cached_repo) / profile.variant
        return str(model_dir.resolve()), f"{repo_id} (subfolder `{profile.variant}`)"

    def _resolve_dtype(self, profile: ModelProfile) -> torch.dtype:
        """Map ``profile.dtype``: "bfloat16" -> torch.bfloat16, anything else -> torch.float32."""
        return torch.bfloat16 if profile.dtype == "bfloat16" else torch.float32

    def _apply_post_load(self, pipe: DiffusionPipeline) -> None:
        """Post-load setup: disable the pipeline's progress bar."""
        pipe.set_progress_bar_config(disable=True)

    def unload(self) -> None:
        """Drop the current pipeline (if any), reset all state and run the garbage collector."""
        if self._pipe is not None:
            del self._pipe
        self._pipe = None
        self._loaded_label = None
        self._loaded_profile = None
        self._on_cuda = False
        gc.collect()

    def move_to_cuda(self) -> None:
        """Move the loaded pipeline to CUDA; no-op if nothing is loaded or it was already moved."""
        if self._pipe is None or self._on_cuda:
            return
        self._pipe = self._pipe.to("cuda")
        self._on_cuda = True

    def load(self, collection: str, variant: str) -> tuple[str, ModelProfile]:
        """Load the pipeline for ``(collection, variant)``.

        Returns ``(status message, profile)``. Returns early, without reloading,
        when the same profile label is already loaded.

        The checkpoint is resolved (and downloaded, for Hub models) *before* the
        current pipeline is unloaded. A failed download therefore keeps the
        previous model usable. The old pipeline is still released before
        ``from_pretrained`` runs.
        """
        profile = get_profile(collection, variant)
        label = profile.label
        if self._loaded_label == label and self._pipe is not None:
            return f"Model already loaded: `{label}`", profile

        model_source, source_label = self._resolve_model_source(profile)
        is_local = self._local_model_dir(profile) is not None
        dtype = self._resolve_dtype(profile)

        # Release the previous pipeline so old and new are never held at the same time.
        self.unload()

        load_kwargs: dict[str, Any] = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
        }
        if profile.use_custom_pipeline:
            load_kwargs["custom_pipeline"] = str(Path(model_source) / "pipeline.py")
        if is_local:
            load_kwargs["local_files_only"] = True

        pipe = DiffusionPipeline.from_pretrained(model_source, **load_kwargs)
        self._apply_post_load(pipe)
        self._pipe = pipe
        self._loaded_label = label
        self._loaded_profile = profile
        self._on_cuda = False
        return f"Loaded `{label}` from `{source_label}`.", profile
# Module-level PipelineManager instance, created at import time.
PIPELINE_MANAGER = PipelineManager()
def current_scheduler_name(pipe: DiffusionPipeline) -> str:
    """Return the class name of the scheduler currently attached to *pipe*."""
    scheduler_type = type(pipe.scheduler)
    return scheduler_type.__name__
def scheduler_options_for_profile(profile: ModelProfile, pipe: DiffusionPipeline | None = None) -> tuple[list[str], str]:
    """Return ``(choices, selected)`` for the scheduler selector of *profile*.

    Native-scheduler profiles offer a single entry. This is the loaded
    pipeline's scheduler class name, or ``"checkpoint"`` when no pipeline is
    given.

    Other profiles offer the catalog's swappable schedulers:

    * With a pipeline, its current scheduler is selected, and it is prepended
      if the catalog does not list it.
    * Without a pipeline, ``"checkpoint"`` is prepended and selected.
    """
    if uses_native_scheduler(profile):
        selected = "checkpoint" if pipe is None else current_scheduler_name(pipe)
        return [selected], selected

    swappable = scheduler_choices_for_profile(profile)
    if pipe is None:
        return ["checkpoint", *swappable], "checkpoint"

    active = current_scheduler_name(pipe)
    options = list(swappable)
    if active not in options:
        options.insert(0, active)
    return options, active
def swap_scheduler(pipe: DiffusionPipeline, scheduler_name: str, profile: ModelProfile) -> None:
    """Replace ``pipe.scheduler`` with the diffusers class named *scheduler_name*.

    The new scheduler is built via ``from_config`` from the current scheduler's
    config. ``profile.scheduler_kwargs`` are forwarded only when *scheduler_name*
    equals ``profile.scheduler``.

    This is a no-op for ``""``, ``"checkpoint"``, or the scheduler already in use.

    Raises:
        ValueError: if *scheduler_name* does not name a class exported by
            ``diffusers`` that provides ``from_config``.
    """
    if scheduler_name in {"", "checkpoint"}:
        return
    if scheduler_name == current_scheduler_name(pipe):
        return

    import diffusers

    scheduler_cls = getattr(diffusers, scheduler_name, None)
    # scheduler_name is a free-form string. Reject submodules, functions and
    # other non-class exports here with a clear error, instead of failing
    # later with an AttributeError on ``from_config``.
    if not (inspect.isclass(scheduler_cls) and hasattr(scheduler_cls, "from_config")):
        raise ValueError(f"Unknown diffusers scheduler: {scheduler_name}")
    if profile.scheduler == scheduler_name:
        extra_kwargs = profile.scheduler_kwargs or {}
    else:
        extra_kwargs = {}
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config, **extra_kwargs)
def _to_int(value: Any, *, default: int = 0) -> int:
if value is None:
return default
if isinstance(value, bool):
return int(value)
if isinstance(value, (int, float)):
return int(value)
text = str(value).strip()
if not text:
return default
return int(float(text))
def _to_float(value: Any, *, default: float = 0.0) -> float:
if value is None:
return default
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return default
return float(text)
def build_inference_steps(profile: ModelProfile, steps: Any) -> int | list[int]:
    """Normalise the requested step count for *profile*.

    Blank or ``None`` input falls back to ``profile.default_steps``.

    When ``profile.steps_are_list`` is set, the result is a four-element list.
    Each element is ``steps // 4``, with a minimum of 1; any remainder is
    dropped. Otherwise the plain integer is returned.
    """
    total = _to_int(steps, default=profile.default_steps)
    if not profile.steps_are_list:
        return total
    return [max(1, total // 4)] * 4
def _split_synonyms(label_text: str) -> list[str]:
return [part.strip() for part in label_text.split(",") if part.strip()]
def _get_pipe_id2label(pipe: DiffusionPipeline) -> dict[int, str]:
id2label: dict[int, str] | None = None
if hasattr(pipe, "id2label"):
raw = pipe.id2label
if isinstance(raw, dict) and raw:
id2label = {int(key): str(value) for key, value in raw.items()}
if id2label is None:
config = getattr(pipe, "config", None)
raw = getattr(config, "id2label", None) if config is not None else None
if isinstance(raw, dict) and raw:
id2label = {int(key): str(value) for key, value in raw.items()}
return id2label or {}
def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
label2id: dict[str, int] = {}
for class_id, value in id2label.items():
synonyms = _split_synonyms(value)
if not synonyms:
continue
for synonym in synonyms:
label2id[synonym] = int(class_id)
label2id[value.strip()] = int(class_id)
return label2id
def resolve_class_labels(
    pipe: DiffusionPipeline,
    class_label: str,
    *,
    default: str,
) -> int | str:
    """Resolve a class name or id using the loaded model's id2label synonyms.

    Resolution order (first hit wins):

    1. Blank input is replaced by *default*.
    2. Decimal strings such as ``"207"`` or ``"207.0"`` become ``int``.
    3. Exact lookup of the whole label, then of each comma-separated part, in
       the synonym table built from the pipeline's id2label.
    4. Case-insensitive comparison of the whole label with each class's full
       label text and its synonyms, in id2label order.
    5. ``pipe.get_label_ids`` for each comma-separated part, when the pipeline
       has it. This step also runs when the pipeline exposes no id2label.

    Returns the class id as ``int``, or the label string unchanged when
    nothing matched.
    """
    label = str(class_label or "").strip() or default
    # isdecimal() rather than isdigit(): characters such as "²" satisfy
    # isdigit() but make int() raise.
    if label.isdecimal():
        return int(label)
    if label.replace(".", "", 1).isdecimal():
        return int(float(label))

    id2label = _get_pipe_id2label(pipe)
    if id2label:
        label2id = _build_label2id(id2label)
        for candidate in (label, *_split_synonyms(label)):
            if candidate in label2id:
                return label2id[candidate]

        normalized = label.casefold()
        for class_id, value in id2label.items():
            if value.strip().casefold() == normalized:
                return int(class_id)
            for part in _split_synonyms(value):
                if part.casefold() == normalized:
                    return int(class_id)

    # Previously skipped whenever id2label was empty, which left pipelines that
    # only provide get_label_ids unable to resolve names.
    if hasattr(pipe, "get_label_ids"):
        for candidate in _split_synonyms(label):
            try:
                return pipe.get_label_ids(candidate)[0]
            except (ValueError, TypeError):
                pass
            # diffusers' DiTPipeline.get_label_ids applies list() to non-list
            # input, which splits a plain string into characters; retry as a list.
            try:
                return pipe.get_label_ids([candidate])[0]
            except (ValueError, TypeError, KeyError, IndexError, AttributeError):
                continue
    return label
def primary_label_for_id(pipe: DiffusionPipeline, class_id: int, *, fallback: str) -> str:
    """Return the first synonym from id2label for a class id.

    Returns *fallback* when the id is missing, or when its label text is empty
    or contains no synonyms.
    """
    text = _get_pipe_id2label(pipe).get(int(class_id))
    names = _split_synonyms(text) if text else []
    return names[0] if names else fallback
def default_class_label_for_pipe(
    pipe: DiffusionPipeline,
    profile: ModelProfile,
    *,
    preferred_ids: tuple[int, ...] = (207, 285, 281),
) -> str:
    """Pick a sensible default label using id2label synonyms when available.

    Order of preference:

    1. ``profile.default_class_label`` when the pipeline has no id2label.
    2. The first synonym of the first id in *preferred_ids* present in id2label.
    3. A synonym that case-insensitively equals ``profile.default_class_label``.
    4. The first synonym of the first id2label entry.

    ``profile.default_class_label`` is also returned whenever the chosen entry
    has no synonyms.

    Args:
        preferred_ids: Class ids tried first, in order. The default
            (207, 285, 281) presumably names ImageNet-1k dog/cat classes;
            confirm against the checkpoints' id2label.
    """
    fallback = profile.default_class_label
    id2label = _get_pipe_id2label(pipe)
    if not id2label:
        return fallback

    for class_id in preferred_ids:
        if class_id in id2label:
            synonyms = _split_synonyms(id2label[class_id])
            return synonyms[0] if synonyms else fallback

    wanted = fallback.casefold()
    for value in id2label.values():
        for synonym in _split_synonyms(value):
            if synonym.casefold() == wanted:
                return synonym

    synonyms = _split_synonyms(next(iter(id2label.values())))
    return synonyms[0] if synonyms else fallback
def _filter_call_kwargs(pipe: DiffusionPipeline, call_kwargs: dict[str, Any]) -> dict[str, Any]:
"""Drop kwargs that the pipeline __call__ does not accept (e.g. height/width for iMF/pMF)."""
try:
params = inspect.signature(pipe.__call__).parameters
except (TypeError, ValueError):
return call_kwargs
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()):
return call_kwargs
accepted = set(params.keys())
return {key: value for key, value in call_kwargs.items() if key in accepted}
def run_inference(
    profile: ModelProfile,
    pipe: DiffusionPipeline,
    *,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    height: int,
    width: int,
    scheduler_name: str | None = None,
    extra_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Generate one image for *class_label* and return ``pipe(...).images[0]``.

    UI values are coerced with ``_to_int``/``_to_float``; blanks fall back to
    the profile defaults, and to 0 for height/width. For non-native profiles,
    a truthy *scheduler_name* is swapped in first.

    *extra_kwargs* is merged into the call; when it is None,
    ``profile.extra_call_kwargs`` is merged instead. ``height``/``width`` are
    forwarded only when both are positive. For the ``ADM-diffusers``
    collection they are forwarded only when they differ from
    ``profile.infer_resolution()``.

    Keyword arguments that the pipeline's ``__call__`` does not accept are
    dropped before the call.
    """
    seed = _to_int(seed, default=profile.default_seed)
    num_steps = _to_int(num_steps, default=profile.default_steps)
    guidance_scale = _to_float(guidance_scale, default=profile.default_guidance)
    height = _to_int(height, default=0)
    width = _to_int(width, default=0)

    if scheduler_name and not uses_native_scheduler(profile):
        swap_scheduler(pipe, str(scheduler_name), profile)

    # Seed on the pipeline's own device. diffusers' randn_tensor rejects a CUDA
    # generator when sampling on CPU. Fall back to the previous hard-coded
    # "cuda" for objects without a ``device`` attribute.
    generator_device = getattr(pipe, "device", None)
    if generator_device is None:
        generator_device = "cuda"
    generator = torch.Generator(device=generator_device).manual_seed(seed)

    call_kwargs: dict[str, Any] = {
        "num_inference_steps": build_inference_steps(profile, num_steps),
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    # An explicit extra_kwargs (even {}) replaces the profile's defaults; a None
    # profile default contributes nothing instead of raising in dict.update.
    extra = extra_kwargs if extra_kwargs is not None else profile.extra_call_kwargs
    call_kwargs.update(extra or {})
    call_kwargs["class_labels"] = resolve_class_labels(
        pipe,
        class_label,
        default=profile.default_class_label,
    )

    if height > 0 and width > 0:
        send_size = True
        if profile.collection == "ADM-diffusers":
            # ADM pipelines receive no size kwargs at their native resolution.
            native = profile.infer_resolution()
            send_size = height != native or width != native
        if send_size:
            call_kwargs["height"] = height
            call_kwargs["width"] = width

    return pipe(**_filter_call_kwargs(pipe, call_kwargs)).images[0]