from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from safetensors.torch import load_file

from .nextdit_crossattn import NextDiTCrossAttn, NextDiTCrossAttnConfig

@dataclass
class DiffusionConfig:
    """Configuration for materializing the NextDiT cross-attention model."""

    # Optional local paths; when omitted, the model keeps its random init and
    # the scheduler is pulled from the Hugging Face Hub (see from_config).
    weights_path: Optional[str] = None
    scheduler_path: Optional[str] = None
    # Transformer dimensions, forwarded verbatim to NextDiTCrossAttnConfig.
    dim: int = 1792
    n_layers: int = 24
    n_heads: int = 28
    n_kv_heads: int = 28
    latent_embedding_size: int = 3584
    input_size: int = 8
    patch_size: int = 1
    in_channels: int = 1792

@dataclass
class DiffusionBundle:
    """The assembled pieces returned by AutoDiffusionModel.from_config."""

    dit: NextDiTCrossAttn
    scheduler: FlowMatchEulerDiscreteScheduler
    load_info: dict

class AutoDiffusionModel:
    """Utility to materialize a NextDiT cross-attention model and scheduler."""

    @classmethod
    def from_config(
        cls,
        config: DiffusionConfig,
        *,
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ) -> DiffusionBundle:
        # Map the flat DiffusionConfig onto the model's own config type.
        dit_config = NextDiTCrossAttnConfig(
            input_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            dim=config.dim,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            latent_embedding_size=config.latent_embedding_size,
        )
        model = NextDiTCrossAttn(dit_config)

        # Optionally load safetensors weights. strict=False tolerates key
        # mismatches; they are surfaced via load_info instead of raising.
        load_info = {"missing_keys": (), "unexpected_keys": ()}
        if config.weights_path:
            weights_path = Path(config.weights_path)
            state_dict = load_file(weights_path)
            load_result = model.load_state_dict(state_dict, strict=False)
            load_info = {
                "missing_keys": load_result.missing_keys,
                "unexpected_keys": load_result.unexpected_keys,
            }

        # Cast dtype first, then move to the requested device.
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        if device is not None:
            model = model.to(device=device)

        # Prefer a local scheduler config; otherwise fall back to the
        # Lumina-Next-SFT scheduler hosted on the Hugging Face Hub.
        if config.scheduler_path:
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.scheduler_path)
        else:
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                "Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler"
            )

        return DiffusionBundle(dit=model, scheduler=scheduler, load_info=load_info)


__all__ = ["DiffusionConfig", "DiffusionBundle", "AutoDiffusionModel"]
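

# A minimal usage sketch, not part of the original module: the dtype/device
# choices and the commented-out checkpoint paths below are illustrative
# assumptions only. Because of the relative import above, this file must be
# run as a module, e.g.:
#   python -m <package>.<this_module>
if __name__ == "__main__":
    config = DiffusionConfig(
        # weights_path="checkpoints/model.safetensors",  # hypothetical path
        # scheduler_path="configs/scheduler",            # hypothetical path
    )
    bundle = AutoDiffusionModel.from_config(
        config,
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.bfloat16,
    )
    # With no weights_path set, load_info holds the empty-tuple defaults.
    print(type(bundle.dit).__name__, bundle.load_info)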