| """Notebook-style diffusers loader for BiliSakura Hub models.""" |
|
|
| from __future__ import annotations |
|
|
| import gc |
| import inspect |
| import os |
| from pathlib import Path |
| from typing import Any, get_args, get_origin |
|
|
| import torch |
| from diffusers import DiffusionPipeline |
| import diffusers.pipelines.pipeline_utils as pipeline_utils |
| from huggingface_hub import snapshot_download |
|
|
| from model_catalog import ( |
| ModelProfile, |
| get_profile, |
| scheduler_choices_for_profile, |
| uses_native_scheduler, |
| ) |
|
|
|
|
def _patch_diffusers_custom_pipeline_type_check() -> None:
    """Work around diffusers 0.36 KeyError when custom pipelines omit parsed annotations.

    ``DiffusionPipeline.from_pretrained`` is replaced by a wrapper that, for the
    duration of each call, swaps ``DiffusionPipeline._get_signature_types`` for a
    lenient variant which returns an entry for *every* ``__init__`` parameter
    (falling back to ``inspect.Signature.empty`` when the annotation cannot be
    interpreted).

    The patch is installed at most once per process; the marker attribute
    ``pipeline_utils._bilisakura_type_check_patch`` records that it is active.
    """
    if getattr(pipeline_utils, "_bilisakura_type_check_patch", False):
        return

    # Sentinel for "DiffusionPipeline does not define the attribute itself".
    not_defined = object()

    @classmethod
    def patched_get_signature_types(cls):
        """Map each ``__init__`` parameter name to a tuple of accepted types."""
        signature_types = {}
        for name, param in inspect.signature(cls.__init__).parameters.items():
            if name == "self":
                continue
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                signature_types[name] = (inspect.Signature.empty,)
                continue
            origin = get_origin(annotation)
            if inspect.isclass(annotation):
                signature_types[name] = (annotation,)
            elif origin is not None:
                # Parameterised generic / union: use its arguments when present.
                args = get_args(annotation)
                signature_types[name] = args if args else (annotation,)
            else:
                # Unparsed annotation (e.g. a string): still register the key.
                signature_types[name] = (inspect.Signature.empty,)
        return signature_types

    original_from_pretrained = DiffusionPipeline.from_pretrained.__func__

    @classmethod
    def from_pretrained_patched(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Run the original ``from_pretrained`` with the lenient type lookup active."""
        # Save the raw class attribute (the classmethod descriptor). Reading it
        # as ``DiffusionPipeline._get_signature_types`` would return a method
        # already bound to DiffusionPipeline; writing that back would make every
        # subclass inspect DiffusionPipeline.__init__ from then on.
        saved = vars(DiffusionPipeline).get("_get_signature_types", not_defined)
        DiffusionPipeline._get_signature_types = patched_get_signature_types
        try:
            return original_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)
        finally:
            if saved is not_defined:
                # Nothing to restore on this class: remove our override so
                # normal attribute lookup applies again.
                del DiffusionPipeline._get_signature_types
            else:
                DiffusionPipeline._get_signature_types = saved

    DiffusionPipeline.from_pretrained = from_pretrained_patched
    pipeline_utils._bilisakura_type_check_patch = True


# Install the workaround as soon as this module is imported.
_patch_diffusers_custom_pipeline_type_check()
|
|
|
|
# Optional root directory for local checkpoints, taken from the environment.
# An unset or empty variable must not enable local mode: Path("") is the
# current working directory, whose is_dir() is always True.
_LOCAL_MODELS_ROOT_ENV = os.environ.get("LOCAL_MODELS_ROOT", "")
LOCAL_MODELS_ROOT = Path(_LOCAL_MODELS_ROOT_ENV).expanduser()
USE_LOCAL_MODELS = bool(_LOCAL_MODELS_ROOT_ENV) and LOCAL_MODELS_ROOT.is_dir()
# Hub organisation name; override with the HF_MODEL_ORG environment variable.
HF_ORG = os.environ.get("HF_MODEL_ORG", "BiliSakura")
|
|
class PipelineManager:
    """Hold at most one loaded diffusers pipeline and remember what it is.

    The manager tracks the pipeline object, the label and profile it was
    loaded from, and whether ``move_to_cuda`` has already been applied.
    """

    def __init__(self) -> None:
        self._pipe: DiffusionPipeline | None = None
        self._loaded_label: str | None = None
        self._loaded_profile: ModelProfile | None = None
        self._on_cuda: bool = False

    @property
    def loaded_label(self) -> str | None:
        """Label of the loaded model, or ``None`` when nothing is loaded."""
        return self._loaded_label

    @property
    def loaded_profile(self) -> ModelProfile | None:
        """Profile of the loaded model, or ``None`` when nothing is loaded."""
        return self._loaded_profile

    @property
    def pipe(self) -> DiffusionPipeline | None:
        """The loaded pipeline, or ``None`` when nothing is loaded."""
        return self._pipe

    def _local_model_dir(self, profile: ModelProfile) -> Path | None:
        """Return ``LOCAL_MODELS_ROOT/<collection>/<variant>`` if usable, else ``None``."""
        if not USE_LOCAL_MODELS:
            return None
        candidate = LOCAL_MODELS_ROOT / profile.collection / profile.variant
        return candidate if candidate.is_dir() else None

    def _resolve_model_source(self, profile: ModelProfile) -> tuple[str, str]:
        """Return local checkpoint path and a human-readable source label.

        A local directory is preferred when available; otherwise only the
        ``<variant>/`` subfolder of ``profile.hub_repo`` is downloaded.
        """
        local_dir = self._local_model_dir(profile)
        if local_dir is not None:
            return str(local_dir.resolve()), str(local_dir)

        repo_id = profile.hub_repo
        cached_repo = snapshot_download(
            repo_id,
            repo_type="model",
            allow_patterns=[f"{profile.variant}/**"],
        )
        model_dir = Path(cached_repo) / profile.variant
        return str(model_dir.resolve()), f"{repo_id} (subfolder `{profile.variant}`)"

    def _resolve_dtype(self, profile: ModelProfile) -> torch.dtype:
        """Map ``profile.dtype`` to a torch dtype; anything but "bfloat16" is float32."""
        return torch.bfloat16 if profile.dtype == "bfloat16" else torch.float32

    def _apply_post_load(self, pipe: DiffusionPipeline) -> None:
        """Apply one-off settings to a freshly loaded pipeline."""
        pipe.set_progress_bar_config(disable=True)

    def unload(self) -> None:
        """Drop the current pipeline and reset all bookkeeping."""
        was_on_cuda = self._on_cuda
        self._pipe = None
        self._loaded_label = None
        self._loaded_profile = None
        self._on_cuda = False
        gc.collect()
        # Dropping the reference does not return PyTorch's cached GPU blocks to
        # the driver; release them explicitly after a CUDA-resident pipeline.
        if was_on_cuda and torch.cuda.is_available():
            torch.cuda.empty_cache()

    def move_to_cuda(self) -> None:
        """Move the loaded pipeline to CUDA once; no-op if absent or already moved."""
        if self._pipe is None or self._on_cuda:
            return
        self._pipe = self._pipe.to("cuda")
        self._on_cuda = True

    def load(self, collection: str, variant: str) -> tuple[str, ModelProfile]:
        """Load the pipeline for ``collection``/``variant``, replacing any other.

        Returns:
            A ``(status message, profile)`` tuple. If the requested model is
            already loaded, nothing is reloaded.

        Raises:
            Whatever ``get_profile``, ``snapshot_download`` or
            ``DiffusionPipeline.from_pretrained`` raise. The previous pipeline
            has already been unloaded by the time loading can fail.
        """
        profile = get_profile(collection, variant)
        label = profile.label
        if self._loaded_label == label and self._pipe is not None:
            return f"Model already loaded: `{label}`", profile

        self.unload()
        model_source, source_label = self._resolve_model_source(profile)
        dtype = self._resolve_dtype(profile)
        model_path = Path(model_source)

        load_kwargs: dict[str, Any] = {
            # NOTE(review): this lets code shipped with the checkpoint execute.
            "trust_remote_code": True,
            "torch_dtype": dtype,
        }
        if profile.use_custom_pipeline:
            load_kwargs["custom_pipeline"] = str(model_path / "pipeline.py")

        # Local checkpoints must not trigger any Hub access.
        if self._local_model_dir(profile) is not None:
            load_kwargs["local_files_only"] = True

        pipe = DiffusionPipeline.from_pretrained(model_source, **load_kwargs)
        self._apply_post_load(pipe)

        self._pipe = pipe
        self._loaded_label = label
        self._loaded_profile = profile
        self._on_cuda = False
        return f"Loaded `{label}` from `{source_label}`.", profile


# Module-level manager instance.
PIPELINE_MANAGER = PipelineManager()
|
|
|
|
def current_scheduler_name(pipe: DiffusionPipeline) -> str:
    """Return the class name of the scheduler currently attached to ``pipe``."""
    scheduler = pipe.scheduler
    return type(scheduler).__name__
|
|
|
|
def scheduler_options_for_profile(profile: ModelProfile, pipe: DiffusionPipeline | None = None) -> tuple[list[str], str]:
    """Return ``(choices, selected)`` scheduler names for ``profile``.

    Without a pipeline the placeholder ``"checkpoint"`` is the selected entry.
    With a pipeline, the selected entry is the class name of its scheduler.
    Profiles that use their native scheduler expose exactly one choice.
    """
    has_pipe = pipe is not None

    if uses_native_scheduler(profile):
        if not has_pipe:
            return ["checkpoint"], "checkpoint"
        only_choice = current_scheduler_name(pipe)
        return [only_choice], only_choice

    candidates = list(scheduler_choices_for_profile(profile))
    if not has_pipe:
        return ["checkpoint", *candidates], "checkpoint"

    active = current_scheduler_name(pipe)
    if active not in candidates:
        # Always offer the scheduler the checkpoint came with, listed first.
        candidates.insert(0, active)
    return candidates, active
|
|
|
|
def swap_scheduler(pipe: DiffusionPipeline, scheduler_name: str, profile: ModelProfile) -> None:
    """Replace ``pipe.scheduler`` with the diffusers class named ``scheduler_name``.

    Nothing happens for ``""``, ``"checkpoint"``, or the name of the scheduler
    that is already attached. The new scheduler is built from the current
    scheduler's config; ``profile.scheduler_kwargs`` are added only when the
    requested name equals ``profile.scheduler``.

    Raises:
        ValueError: if the ``diffusers`` package has no attribute of that name.
    """
    keep_current = (
        scheduler_name in {"", "checkpoint"}
        or scheduler_name == current_scheduler_name(pipe)
    )
    if keep_current:
        return

    import diffusers

    if not hasattr(diffusers, scheduler_name):
        raise ValueError(f"Unknown diffusers scheduler: {scheduler_name}")
    target_cls = getattr(diffusers, scheduler_name)

    if profile.scheduler == scheduler_name:
        overrides = profile.scheduler_kwargs
    else:
        overrides = {}
    pipe.scheduler = target_cls.from_config(pipe.scheduler.config, **overrides)
|
|
|
|
def _to_int(value: Any, *, default: int = 0) -> int:
    """Coerce a loosely typed value to ``int``.

    ``None`` and blank strings yield ``default``. Numbers (including ``bool``)
    are truncated with ``int()``. Any other value is converted to text and
    parsed.

    Raises:
        ValueError: if the text is not numeric (or is "nan").
        OverflowError: if the value is infinite.
    """
    if value is None:
        return default
    if isinstance(value, (int, float)):  # bool is an int subclass
        return int(value)
    text = str(value).strip()
    if not text:
        return default
    try:
        # Parse integers directly: going through float() would silently round
        # large values such as seeds above 2**53.
        return int(text)
    except ValueError:
        # Decimal or exponent notation ("3.7", "1e3"): truncate toward zero.
        return int(float(text))
|
|
|
|
def _to_float(value: Any, *, default: float = 0.0) -> float:
    """Coerce a loosely typed value to ``float``.

    ``None`` and blank strings yield ``default``; numbers are converted
    directly; anything else is converted to text, stripped and parsed.

    Raises:
        ValueError: if the text cannot be parsed as a float.
    """
    if value is None:
        return default
    if not isinstance(value, (int, float)):
        value = str(value).strip()
        if not value:
            return default
    return float(value)
|
|
|
|
def build_inference_steps(profile: ModelProfile, steps: Any) -> int | list[int]:
    """Return the step count in the form the profile's pipeline expects.

    ``steps`` is coerced to an int, defaulting to ``profile.default_steps``.
    When ``profile.steps_are_list`` is set, the total is split evenly into
    four entries of at least one step each; otherwise the int is returned.
    """
    total = _to_int(steps, default=profile.default_steps)
    if not profile.steps_are_list:
        return total
    return [max(1, total // 4)] * 4
|
|
|
|
def _split_synonyms(label_text: str) -> list[str]:
    """Split a comma-separated label into stripped, non-empty synonyms."""
    stripped = (piece.strip() for piece in label_text.split(","))
    return [piece for piece in stripped if piece]
|
|
|
|
def _get_pipe_id2label(pipe: DiffusionPipeline) -> dict[int, str]:
    """Return the pipeline's ``id2label`` mapping with int keys and str values.

    ``pipe.id2label`` is tried first, then ``pipe.config.id2label``. A source
    counts only if it is a non-empty dict. An empty dict is returned when
    neither source qualifies.
    """

    def _normalise(mapping: Any) -> dict[int, str] | None:
        if isinstance(mapping, dict) and mapping:
            return {int(key): str(value) for key, value in mapping.items()}
        return None

    found = _normalise(pipe.id2label) if hasattr(pipe, "id2label") else None
    if found is None:
        config = getattr(pipe, "config", None)
        found = _normalise(getattr(config, "id2label", None) if config is not None else None)
    return found or {}
|
|
|
|
def _build_label2id(id2label: dict[int, str]) -> dict[str, int]:
    """Invert ``id2label`` into a name-to-id lookup.

    Each comma-separated synonym and the full stripped label text map to the
    class id. Entries without any synonym are skipped. When a name occurs
    under several ids, the id seen last wins.
    """
    lookup: dict[str, int] = {}
    for raw_id, label_text in id2label.items():
        names = _split_synonyms(label_text)
        if names:
            target = int(raw_id)
            lookup.update(dict.fromkeys(names, target))
            lookup[label_text.strip()] = target
    return lookup
|
|
|
|
def resolve_class_labels(
    pipe: DiffusionPipeline,
    class_label: str,
    *,
    default: str,
) -> int | str:
    """Resolve a class name or id using the loaded model's id2label synonyms.

    Resolution order:
      1. a numeric label ("207", "207.0") is returned as an int;
      2. exact match of the whole label, then of each comma-separated part;
      3. case-insensitive match of the whole label against each id2label
         value and each of its synonyms;
      4. ``pipe.get_label_ids`` (when the pipeline provides it) per part.

    Returns:
        The class id, or the label text unchanged when nothing matched or the
        pipeline exposes no id2label mapping.
    """
    label = str(class_label or "").strip() or default
    # isdecimal() rather than isdigit(): the latter also accepts characters
    # such as superscript digits that int()/float() reject with ValueError.
    if label.isdecimal():
        return int(label)
    if label.replace(".", "", 1).isdecimal():
        return int(float(label))

    id2label = _get_pipe_id2label(pipe)
    if not id2label:
        return label

    label2id = _build_label2id(id2label)
    if label in label2id:
        return label2id[label]

    for part in _split_synonyms(label):
        if part in label2id:
            return label2id[part]

    normalized = label.casefold()
    for class_id, value in id2label.items():
        if value.strip().casefold() == normalized:
            return int(class_id)
        for part in _split_synonyms(value):
            if part.casefold() == normalized:
                return int(class_id)

    if hasattr(pipe, "get_label_ids"):
        for candidate in _split_synonyms(label):
            try:
                return pipe.get_label_ids(candidate)[0]
            except (ValueError, TypeError, KeyError, IndexError):
                # Best-effort lookup: an unknown name or an empty result just
                # means "try the next candidate".
                continue

    return label
|
|
|
|
def primary_label_for_id(pipe: DiffusionPipeline, class_id: int, *, fallback: str) -> str:
    """Return the first synonym from id2label for a class id.

    ``fallback`` is returned when the id is unknown, its label is empty, or
    the label contains no usable synonym.
    """
    label_text = _get_pipe_id2label(pipe).get(int(class_id))
    names = _split_synonyms(label_text) if label_text else []
    return names[0] if names else fallback
|
|
|
|
def default_class_label_for_pipe(pipe: DiffusionPipeline, profile: ModelProfile) -> str:
    """Pick a sensible default label using id2label synonyms when available.

    Order of preference: the first of the hard-coded class ids present in
    id2label; a synonym equal (case-insensitively) to
    ``profile.default_class_label``; the first synonym of the first id2label
    entry; and finally ``profile.default_class_label`` itself.
    """
    id2label = _get_pipe_id2label(pipe)
    if not id2label:
        return profile.default_class_label

    # NOTE(review): presumably ImageNet-1k class indices - confirm.
    favourite = next((cid for cid in (207, 285, 281) if cid in id2label), None)
    if favourite is not None:
        return primary_label_for_id(pipe, favourite, fallback=profile.default_class_label)

    for label_text in id2label.values():
        for name in _split_synonyms(label_text):
            if name.casefold() == profile.default_class_label.casefold():
                return name

    leading = _split_synonyms(next(iter(id2label.values()), profile.default_class_label))
    return leading[0] if leading else profile.default_class_label
|
|
|
|
def _filter_call_kwargs(pipe: DiffusionPipeline, call_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Drop kwargs that the pipeline __call__ does not accept (e.g. height/width for iMF/pMF).

    ``call_kwargs`` is returned untouched when the signature cannot be
    inspected or when ``__call__`` takes ``**kwargs``; otherwise a new dict
    holding only the accepted keys is returned.
    """
    try:
        signature = inspect.signature(pipe.__call__)
    except (TypeError, ValueError):
        return call_kwargs

    kinds = [parameter.kind for parameter in signature.parameters.values()]
    if inspect.Parameter.VAR_KEYWORD in kinds:
        return call_kwargs

    return {
        key: value
        for key, value in call_kwargs.items()
        if key in signature.parameters
    }
|
|
|
|
def run_inference(
    profile: ModelProfile,
    pipe: DiffusionPipeline,
    *,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    height: int,
    width: int,
    scheduler_name: str | None = None,
    extra_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Run one class-conditional generation and return the first image.

    Numeric arguments are coerced leniently; blank or ``None`` values fall
    back to the profile defaults (0 for ``height``/``width``). The scheduler
    is swapped first when ``scheduler_name`` is given and the profile does not
    use its native scheduler. ``extra_kwargs`` replaces
    ``profile.extra_call_kwargs`` when provided. ``height``/``width`` are
    forwarded only when both are positive and, for the "ADM-diffusers"
    collection, only when they differ from ``profile.infer_resolution()``.
    Keyword arguments the pipeline's ``__call__`` does not accept are dropped.

    Returns:
        ``pipe(...).images[0]``.
    """
    seed = _to_int(seed, default=profile.default_seed)
    num_steps = _to_int(num_steps, default=profile.default_steps)
    guidance_scale = _to_float(guidance_scale, default=profile.default_guidance)
    height = _to_int(height, default=0)
    width = _to_int(width, default=0)

    if scheduler_name and not uses_native_scheduler(profile):
        swap_scheduler(pipe, str(scheduler_name), profile)

    # Seed on CUDA as before whenever it exists; on CPU-only hosts fall back
    # to a CPU generator instead of raising while constructing it.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(seed)
    call_kwargs: dict[str, Any] = {
        "num_inference_steps": build_inference_steps(profile, num_steps),
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    extras = extra_kwargs if extra_kwargs is not None else profile.extra_call_kwargs
    # Tolerate a profile with no extra kwargs (None) rather than failing.
    call_kwargs.update(extras or {})

    # Set after the extras so they cannot override the resolved class label.
    call_kwargs["class_labels"] = resolve_class_labels(
        pipe,
        class_label,
        default=profile.default_class_label,
    )

    native = profile.infer_resolution()
    if height > 0 and width > 0:
        if profile.collection != "ADM-diffusers" or height != native or width != native:
            call_kwargs["height"] = height
            call_kwargs["width"] = width

    return pipe(**_filter_call_kwargs(pipe, call_kwargs)).images[0]
|
|