Diffusers
Safetensors
File size: 3,715 Bytes
c5cfae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Helpers for loading transformer variants from ``transformer/<subfolder>/``."""

from __future__ import annotations

import importlib.util
from pathlib import Path

import torch
from diffusers.models.transformers import SD3Transformer2DModel


def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Return the shift value for *image_seq_len* by linear interpolation.

    The result lies on the straight line through the two points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``. Inputs
    outside ``[base_seq_len, max_seq_len]`` are extrapolated along the same
    line; no clamping is applied.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len``.
    """
    shift_span = max_shift - base_shift
    seq_span = max_seq_len - base_seq_len
    slope = shift_span / seq_span
    # Intercept chosen so the line passes through (base_seq_len, base_shift).
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


def set_flow_timesteps(
    scheduler,
    transformer,
    num_inference_steps: int,
    latent_height: int,
    latent_width: int,
    device: torch.device,
) -> None:
    """Configure *scheduler* timesteps, passing ``mu`` when dynamic shifting is on.

    When ``scheduler.config`` has a truthy ``use_dynamic_shifting`` entry, the
    image sequence length is computed as the product of the latent height and
    width, each floor-divided by ``transformer.config.patch_size``. That length
    is fed to ``calculate_shift`` together with the scheduler's
    ``base_image_seq_len`` / ``max_image_seq_len`` / ``base_shift`` /
    ``max_shift`` settings (falling back to 256 / 4096 / 0.5 / 1.15), and the
    result is passed to ``scheduler.set_timesteps`` as ``mu``.

    Otherwise ``scheduler.set_timesteps`` is called without ``mu``, and
    *transformer* is not touched.
    """
    settings = scheduler.config
    shift_kwargs = {}

    if settings.get("use_dynamic_shifting", False):
        patch = transformer.config.patch_size
        tokens_high = latent_height // patch
        tokens_wide = latent_width // patch
        shift_kwargs["mu"] = calculate_shift(
            tokens_high * tokens_wide,
            settings.get("base_image_seq_len", 256),
            settings.get("max_image_seq_len", 4096),
            settings.get("base_shift", 0.5),
            settings.get("max_shift", 1.15),
        )

    # ``mu`` is only forwarded when it was computed above.
    scheduler.set_timesteps(num_inference_steps, device=device, **shift_kwargs)


def resolve_repo_dir(pretrained_model_name_or_path: str | Path) -> Path:
    """Return *pretrained_model_name_or_path* as a resolved ``Path``.

    The argument is always treated as a filesystem path; existence is not
    checked here.
    """
    repo_path = Path(pretrained_model_name_or_path)
    return repo_path.resolve()


def load_transformer_from_subfolder(
    repo_dir: str | Path,
    transformer_subfolder: str,
    *,
    dtype: torch.dtype = torch.bfloat16,
    device: str | torch.device | None = None,
):
    """Load the transformer stored in ``<repo_dir>/transformer/<transformer_subfolder>/``.

    If that folder contains ``transformer_intrinsic_weather.py``, the file is
    executed as a module and its ``IntrinsicWeatherSD3Transformer2DModel``
    class is used for loading; otherwise ``SD3Transformer2DModel`` is used.
    Weights are loaded with ``local_files_only=True``.

    Args:
        repo_dir: Repository root; resolved via ``resolve_repo_dir``.
        transformer_subfolder: Variant folder name under ``transformer/``.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.
        device: When not ``None``, the loaded model is moved there with ``.to``.

    Returns:
        The loaded transformer model.

    Raises:
        FileNotFoundError: If the variant folder is not a directory.
        ImportError: If no import spec/loader can be built for the custom module.
    """
    variant_dir = resolve_repo_dir(repo_dir) / "transformer" / transformer_subfolder
    if not variant_dir.is_dir():
        raise FileNotFoundError(f"Transformer folder not found: {variant_dir}")

    # Stock class unless the variant ships its own implementation.
    model_cls = SD3Transformer2DModel

    custom_source = variant_dir / "transformer_intrinsic_weather.py"
    if custom_source.exists():
        # NOTE(review): this executes Python code found inside the checkpoint
        # folder; only point this at folders you trust.
        module_spec = importlib.util.spec_from_file_location(
            "transformer_intrinsic_weather", custom_source
        )
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot import custom transformer module: {custom_source}")
        # The module is executed ad hoc and is not added to ``sys.modules``.
        custom = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(custom)
        model_cls = custom.IntrinsicWeatherSD3Transformer2DModel

    model = model_cls.from_pretrained(
        variant_dir.as_posix(),
        torch_dtype=dtype,
        local_files_only=True,
    )
    return model if device is None else model.to(device)


def resolve_transformer_lora_dir(repo_dir: str | Path, transformer_subfolder: str) -> Path | None:
    """Locate the LoRA folder bundled with a transformer variant.

    Returns ``<repo_dir>/transformer/<transformer_subfolder>/lora`` when that
    directory exists and directly contains at least one ``*.safetensors``
    file; returns ``None`` otherwise.
    """
    candidate = resolve_repo_dir(repo_dir) / "transformer" / transformer_subfolder / "lora"
    if not candidate.is_dir():
        return None
    # Pull at most one match from the glob; we only need to know one exists.
    first_weight_file = next(candidate.glob("*.safetensors"), None)
    return candidate if first_weight_file is not None else None


def load_transformer_lora(pipe, repo_dir: str | Path, transformer_subfolder: str) -> bool:
    """Load a transformer variant's bundled LoRA weights into *pipe*, if any.

    The LoRA folder is looked up with ``resolve_transformer_lora_dir``. When
    one is found, its POSIX-style path is handed to
    ``pipe.load_lora_weights``.

    Returns:
        ``True`` if weights were loaded, ``False`` if no LoRA folder was found.
    """
    weights_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
    found = weights_dir is not None
    if found:
        pipe.load_lora_weights(weights_dir.as_posix())
    return found