Diffusers
Safetensors
IntrisicWeather-diffusers / pipeline_utils.py
BiliSakura's picture
Upload folder using huggingface_hub
c5cfae9 verified
Raw
History Blame Contribute Delete
3.72 kB
"""Helpers for loading transformer variants from ``transformer/<subfolder>/``."""
from __future__ import annotations

import importlib.util
import sys
from pathlib import Path

import torch
from diffusers.models.transformers import SD3Transformer2DModel
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Map a sequence length onto a timestep shift value by linear interpolation.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that interval are
    extrapolated along the same line, not clamped.

    Raises:
        ZeroDivisionError: if ``max_seq_len == base_seq_len``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
def set_flow_timesteps(
    scheduler,
    transformer,
    num_inference_steps: int,
    latent_height: int,
    latent_width: int,
    device: torch.device,
) -> None:
    """Call ``scheduler.set_timesteps``, adding a resolution-dependent ``mu`` when enabled.

    If the scheduler config has a truthy ``use_dynamic_shifting``, the number
    of transformer patches in the latent grid
    (``(latent_height // patch_size) * (latent_width // patch_size)``) is
    turned into ``mu`` with ``calculate_shift``. The interpolation bounds come
    from the scheduler config keys ``base_image_seq_len``, ``max_image_seq_len``,
    ``base_shift`` and ``max_shift``, which default to 256, 4096, 0.5 and 1.15.
    Otherwise ``set_timesteps`` is called without ``mu``.
    """
    cfg = scheduler.config
    if not cfg.get("use_dynamic_shifting", False):
        scheduler.set_timesteps(num_inference_steps, device=device)
        return

    patch = transformer.config.patch_size
    num_patches = (latent_height // patch) * (latent_width // patch)
    mu = calculate_shift(
        num_patches,
        base_seq_len=cfg.get("base_image_seq_len", 256),
        max_seq_len=cfg.get("max_image_seq_len", 4096),
        base_shift=cfg.get("base_shift", 0.5),
        max_shift=cfg.get("max_shift", 1.15),
    )
    scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
def resolve_repo_dir(pretrained_model_name_or_path: str | Path) -> Path:
return Path(pretrained_model_name_or_path).resolve()
def _import_custom_transformer_class(module_path: Path):
    """Import ``module_path`` and return its ``IntrinsicWeatherSD3Transformer2DModel`` class."""
    module_name = "transformer_intrinsic_weather"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import custom transformer module: {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, per the documented importlib recipe, so the
    # module (and classes defined in it) can be resolved by name, e.g. by pickle.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module importable.
        sys.modules.pop(module_name, None)
        raise
    return module.IntrinsicWeatherSD3Transformer2DModel


def load_transformer_from_subfolder(
    repo_dir: str | Path,
    transformer_subfolder: str,
    *,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device | None = None,
):
    """Load a transformer checkpoint from ``<repo_dir>/transformer/<transformer_subfolder>/``.

    If the folder contains ``transformer_intrinsic_weather.py``, that file is
    imported and its ``IntrinsicWeatherSD3Transformer2DModel`` class is used.
    Otherwise ``SD3Transformer2DModel`` is used. Either way the weights are
    loaded with ``from_pretrained(..., torch_dtype=dtype, local_files_only=True)``.

    Security: the custom module is executed as Python code, so only load
    repositories you trust.

    Args:
        repo_dir: Local repository root.
        transformer_subfolder: Variant folder name under ``transformer/``.
        dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
        device: If not ``None``, the loaded model is moved there with ``.to(device)``.

    Returns:
        The loaded transformer model.

    Raises:
        FileNotFoundError: if the variant folder does not exist.
        ImportError: if no import spec can be built for the custom module file.
        AttributeError: if the custom module lacks
            ``IntrinsicWeatherSD3Transformer2DModel``.
    """
    repo_dir = resolve_repo_dir(repo_dir)
    transformer_path = repo_dir / "transformer" / transformer_subfolder
    if not transformer_path.is_dir():
        raise FileNotFoundError(f"Transformer folder not found: {transformer_path}")

    custom_module = transformer_path / "transformer_intrinsic_weather.py"
    if custom_module.is_file():
        model_cls = _import_custom_transformer_class(custom_module)
    else:
        model_cls = SD3Transformer2DModel

    transformer = model_cls.from_pretrained(
        transformer_path.as_posix(),
        torch_dtype=dtype,
        local_files_only=True,
    )
    if device is not None:
        transformer = transformer.to(device)
    return transformer
def resolve_transformer_lora_dir(repo_dir: str | Path, transformer_subfolder: str) -> Path | None:
    """Locate the LoRA folder bundled with a transformer variant.

    Returns ``<repo_dir>/transformer/<transformer_subfolder>/lora`` when that
    path is a directory containing at least one ``*.safetensors`` file directly
    inside it. Returns ``None`` otherwise.
    """
    candidate = resolve_repo_dir(repo_dir).joinpath("transformer", transformer_subfolder, "lora")
    if not candidate.is_dir():
        return None
    first_weight = next(candidate.glob("*.safetensors"), None)
    return None if first_weight is None else candidate
def load_transformer_lora(pipe, repo_dir: str | Path, transformer_subfolder: str) -> bool:
    """Attach the LoRA weights shipped with a transformer variant to ``pipe``.

    The directory is looked up with ``resolve_transformer_lora_dir``. When one
    is found, its POSIX path string is passed to ``pipe.load_lora_weights``.

    Returns:
        ``True`` if weights were loaded, ``False`` if no bundled LoRA directory
        was found (nothing is called on ``pipe`` in that case).
    """
    found = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
    if found is not None:
        pipe.load_lora_weights(found.as_posix())
    return found is not None