"""Helpers for loading transformer variants from ``transformer/<subfolder>/``.

A repository holds one or more transformer variants laid out as
``<repo_dir>/transformer/<subfolder>/``. A variant is loaded with the stock
``SD3Transformer2DModel`` unless its folder ships a
``transformer_intrinsic_weather.py`` file, in which case the
``IntrinsicWeatherSD3Transformer2DModel`` class defined there is used. A variant
may also bundle LoRA weights (``*.safetensors``) in a ``lora/`` subfolder.
"""

from __future__ import annotations

import importlib.util
import sys
from pathlib import Path

import torch
from diffusers.models.transformers import SD3Transformer2DModel

# Module name (and file stem) / class name of the optional custom transformer
# implementation shipped inside a variant folder.
_CUSTOM_MODULE_NAME = "transformer_intrinsic_weather"
_CUSTOM_CLASS_NAME = "IntrinsicWeatherSD3Transformer2DModel"


def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Linearly map an image token count to a timestep-shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; values of ``image_seq_len`` outside that
    range are extrapolated, not clamped.

    Args:
        image_seq_len: Number of image tokens in the latent.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift returned at ``base_seq_len``.
        max_shift: Shift returned at ``max_seq_len``.

    Returns:
        The interpolated shift value.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len``.
    """
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


def set_flow_timesteps(
    scheduler,
    transformer,
    num_inference_steps: int,
    latent_height: int,
    latent_width: int,
    device: torch.device,
) -> None:
    """Set ``scheduler`` timesteps, adding a resolution-dependent ``mu`` if enabled.

    When ``scheduler.config["use_dynamic_shifting"]`` is truthy, the number of
    image tokens is computed from the latent size and the transformer's
    ``patch_size``. That count is mapped to ``mu`` via :func:`calculate_shift`,
    using the scheduler's own ``base_image_seq_len`` / ``max_image_seq_len`` /
    ``base_shift`` / ``max_shift`` config values (with fallbacks). Otherwise the
    timesteps are set without ``mu``.

    Args:
        scheduler: Object with a dict-like ``config`` (supporting ``get``) and a
            ``set_timesteps`` method.
        transformer: Model whose ``config.patch_size`` is used to count tokens.
        num_inference_steps: Number of denoising steps.
        latent_height: Latent height (before patchification).
        latent_width: Latent width (before patchification).
        device: Device the scheduler should place its timesteps on.
    """
    if scheduler.config.get("use_dynamic_shifting", False):
        patch_size = transformer.config.patch_size
        image_seq_len = (latent_height // patch_size) * (latent_width // patch_size)
        mu = calculate_shift(
            image_seq_len,
            scheduler.config.get("base_image_seq_len", 256),
            scheduler.config.get("max_image_seq_len", 4096),
            scheduler.config.get("base_shift", 0.5),
            scheduler.config.get("max_shift", 1.15),
        )
        scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device)


def resolve_repo_dir(pretrained_model_name_or_path: str | Path) -> Path:
    """Return ``pretrained_model_name_or_path`` as an absolute, resolved ``Path``."""
    return Path(pretrained_model_name_or_path).resolve()


def _import_custom_transformer_class(module_path: Path) -> type:
    """Execute ``module_path`` and return its custom transformer class.

    Follows the ``importlib`` "import a source file directly" recipe. The module
    is registered in ``sys.modules`` under ``transformer_intrinsic_weather``
    before it runs, so code that resolves the class by module name can find it
    (e.g. ``pickle``, or ``dataclasses`` handling string annotations). This
    replaces any existing entry of the same name. If execution fails, the
    previous entry (or its absence) is restored.

    Raises:
        ImportError: If no loader can be created for the file, or the module
            does not define ``IntrinsicWeatherSD3Transformer2DModel``.
    """
    spec = importlib.util.spec_from_file_location(_CUSTOM_MODULE_NAME, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import custom transformer module: {module_path}")
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(_CUSTOM_MODULE_NAME)
    sys.modules[_CUSTOM_MODULE_NAME] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module importable by name.
        if previous is None:
            sys.modules.pop(_CUSTOM_MODULE_NAME, None)
        else:
            sys.modules[_CUSTOM_MODULE_NAME] = previous
        raise

    try:
        return getattr(module, _CUSTOM_CLASS_NAME)
    except AttributeError as err:
        raise ImportError(
            f"Custom transformer module {module_path} does not define {_CUSTOM_CLASS_NAME}"
        ) from err


def load_transformer_from_subfolder(
    repo_dir: str | Path,
    transformer_subfolder: str,
    *,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device | None = None,
):
    """Load a transformer checkpoint from ``<repo_dir>/transformer/<transformer_subfolder>/``.

    If the folder contains ``transformer_intrinsic_weather.py``, that file is
    executed and its ``IntrinsicWeatherSD3Transformer2DModel`` class is used.
    Otherwise ``SD3Transformer2DModel`` is used. Either way, weights are loaded
    via ``from_pretrained`` with ``local_files_only=True``.

    Args:
        repo_dir: Repository root; resolved to an absolute path.
        transformer_subfolder: Variant folder name under ``transformer/``.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.
        device: If not ``None``, the model is moved there with ``.to(device)``.

    Returns:
        The loaded transformer model.

    Raises:
        FileNotFoundError: If the variant folder does not exist.
        ImportError: If the custom module cannot be loaded or lacks the
            expected class. Errors raised while executing the module propagate
            unchanged.
    """
    repo_dir = resolve_repo_dir(repo_dir)
    transformer_path = repo_dir / "transformer" / transformer_subfolder
    if not transformer_path.is_dir():
        raise FileNotFoundError(f"Transformer folder not found: {transformer_path}")

    custom_module = transformer_path / f"{_CUSTOM_MODULE_NAME}.py"
    if custom_module.exists():
        model_cls = _import_custom_transformer_class(custom_module)
    else:
        model_cls = SD3Transformer2DModel

    transformer = model_cls.from_pretrained(
        transformer_path.as_posix(),
        torch_dtype=dtype,
        local_files_only=True,
    )
    if device is not None:
        transformer = transformer.to(device)
    return transformer


def resolve_transformer_lora_dir(repo_dir: str | Path, transformer_subfolder: str) -> Path | None:
    """Return ``<repo_dir>/transformer/<transformer_subfolder>/lora`` if it holds LoRA weights.

    Returns:
        The ``lora`` directory when it exists and contains at least one
        ``*.safetensors`` file, otherwise ``None``.
    """
    lora_dir = resolve_repo_dir(repo_dir) / "transformer" / transformer_subfolder / "lora"
    if lora_dir.is_dir() and any(lora_dir.glob("*.safetensors")):
        return lora_dir
    return None


def load_transformer_lora(pipe, repo_dir: str | Path, transformer_subfolder: str) -> bool:
    """Load LoRA weights bundled with a transformer variant into ``pipe``.

    Args:
        pipe: Object exposing ``load_lora_weights(path)``; presumably a
            diffusers pipeline — confirm against callers.
        repo_dir: Repository root.
        transformer_subfolder: Variant folder name under ``transformer/``.

    Returns:
        ``True`` if a ``lora`` directory with ``*.safetensors`` files was found
        and passed to ``pipe.load_lora_weights``, ``False`` if there was none.
    """
    lora_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
    if lora_dir is None:
        return False
    pipe.load_lora_weights(lora_dir.as_posix())
    return True