|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
from .nextdit_crossattn import NextDiTCrossAttn, NextDiTCrossAttnConfig |
|
|
|
|
|
|
|
|
@dataclass
class DiffusionConfig:
    """Hyperparameters and optional asset paths for building a diffusion bundle.

    All architecture fields default to the reference NextDiT cross-attention
    configuration; ``weights_path``/``scheduler_path`` stay ``None`` to mean
    "randomly initialized" / "use the Hub fallback scheduler" respectively.
    """

    # Optional local .safetensors checkpoint to load into the DiT.
    weights_path: str | None = None
    # Optional local/Hub path for the scheduler config.
    scheduler_path: str | None = None
    # Transformer width.
    dim: int = 1792
    # Number of transformer blocks.
    n_layers: int = 24
    # Attention heads (query).
    n_heads: int = 28
    # Attention heads (key/value); equal to n_heads means no GQA sharing.
    n_kv_heads: int = 28
    # Width of the conditioning latent embedding.
    latent_embedding_size: int = 3584
    # Spatial size of the input latent grid.
    input_size: int = 8
    # Patchification factor over the latent grid.
    patch_size: int = 1
    # Channels of the input latents.
    in_channels: int = 1792
|
|
|
|
|
|
|
|
@dataclass
class DiffusionBundle:
    """Materialized diffusion components returned by AutoDiffusionModel.from_config."""

    # The instantiated NextDiT cross-attention transformer (weights possibly loaded).
    dit: NextDiTCrossAttn
    # Flow-matching Euler scheduler paired with the model.
    scheduler: FlowMatchEulerDiscreteScheduler
    # Reports "missing_keys"/"unexpected_keys" from state-dict loading
    # (empty collections when no checkpoint was loaded).
    load_info: dict
|
|
|
|
|
|
|
|
class AutoDiffusionModel:
    """Utility to materialize a NextDiT cross-attention model and scheduler."""

    @classmethod
    def from_config(
        cls,
        config: DiffusionConfig,
        *,
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ) -> DiffusionBundle:
        """Build a :class:`DiffusionBundle` from *config*.

        Args:
            config: Architecture hyperparameters plus optional checkpoint and
                scheduler paths.
            device: Target device string (e.g. ``"cuda"``); model stays on its
                default device when ``None``.
            torch_dtype: Target parameter dtype; parameters keep their default
                dtype when ``None``.

        Returns:
            A bundle holding the (optionally weight-loaded) DiT, a flow-match
            Euler scheduler, and a ``load_info`` dict with ``missing_keys`` /
            ``unexpected_keys`` lists describing state-dict mismatches.
        """
        dit_config = NextDiTCrossAttnConfig(
            input_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            dim=config.dim,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            latent_embedding_size=config.latent_embedding_size,
        )
        model = NextDiTCrossAttn(dit_config)

        # Use lists in both branches so load_info carries a consistent value
        # type (the original mixed empty tuples with lists produced by
        # load_state_dict).
        load_info: dict = {"missing_keys": [], "unexpected_keys": []}

        if config.weights_path:
            state_dict = load_file(Path(config.weights_path))
            # strict=False: checkpoints may legitimately omit buffers or carry
            # extra keys; mismatches are surfaced via load_info, not raised.
            load_result = model.load_state_dict(state_dict, strict=False)
            load_info = {
                "missing_keys": list(load_result.missing_keys),
                "unexpected_keys": list(load_result.unexpected_keys),
            }

        # A single .to() call handles dtype and/or device; nn.Module.to treats
        # None arguments as no-ops.
        if torch_dtype is not None or device is not None:
            model = model.to(device=device, dtype=torch_dtype)

        if config.scheduler_path:
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(config.scheduler_path)
        else:
            # Fall back to the Lumina-Next reference scheduler config on the Hub.
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                "Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler"
            )

        return DiffusionBundle(dit=model, scheduler=scheduler, load_info=load_info)
|
|
|
|
|
|
|
|
# Public API of this module.
__all__ = ["DiffusionConfig", "DiffusionBundle", "AutoDiffusionModel"]
|
|
|
|
|
|