from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from safetensors.torch import load_file

from .nextdit_crossattn import NextDiTCrossAttn, NextDiTCrossAttnConfig

# Scheduler used when ``DiffusionConfig.scheduler_path`` is not set. Resolved via
# ``from_pretrained``, so it may need Hugging Face Hub access unless already cached.
DEFAULT_SCHEDULER_REPO = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
DEFAULT_SCHEDULER_SUBFOLDER = "scheduler"


@dataclass
class DiffusionConfig:
    """Architecture hyper-parameters and file locations for a NextDiT model.

    The architecture fields (``dim`` .. ``in_channels``) are forwarded
    unchanged to ``NextDiTCrossAttnConfig``.
    """

    # Optional path to a ``.safetensors`` checkpoint; when unset (or empty) the
    # model keeps the weights produced by its constructor.
    weights_path: Optional[str] = None
    # Optional local directory or Hub id for the scheduler config; when unset
    # (or empty) DEFAULT_SCHEDULER_REPO / DEFAULT_SCHEDULER_SUBFOLDER are used.
    scheduler_path: Optional[str] = None
    dim: int = 1792
    n_layers: int = 24
    n_heads: int = 28
    n_kv_heads: int = 28
    latent_embedding_size: int = 3584
    input_size: int = 8
    patch_size: int = 1
    in_channels: int = 1792


@dataclass
class DiffusionBundle:
    """Result of ``AutoDiffusionModel.from_config``."""

    dit: NextDiTCrossAttn
    scheduler: FlowMatchEulerDiscreteScheduler
    # ``{"missing_keys": [...], "unexpected_keys": [...]}`` from checkpoint
    # loading; both lists are empty when no checkpoint was loaded.
    load_info: dict


class AutoDiffusionModel:
    """Utility to materialize a NextDiT cross-attention model and scheduler."""

    @classmethod
    def from_config(
        cls,
        config: DiffusionConfig,
        *,
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
        strict: bool = False,
    ) -> DiffusionBundle:
        """Build the DiT, optionally load its weights, and load the scheduler.

        Args:
            config: Architecture hyper-parameters and file locations.
            device: If given, the model is moved to this device after loading.
            torch_dtype: If given, the model is cast to this dtype after loading.
            strict: Forwarded to ``load_state_dict``. When False (the default),
                key mismatches are reported in ``load_info`` instead of raising.

        Returns:
            A ``DiffusionBundle`` holding the model, the scheduler and a
            ``load_info`` dict with ``missing_keys`` / ``unexpected_keys`` lists.

        Raises:
            FileNotFoundError: If ``config.weights_path`` is set but is not a file.
        """
        model = NextDiTCrossAttn(cls._build_dit_config(config))
        load_info = cls._load_weights(model, config.weights_path, strict=strict)

        # Cast before moving so that a down-cast shrinks the device transfer.
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        if device is not None:
            model = model.to(device=device)

        scheduler = cls._load_scheduler(config.scheduler_path)
        return DiffusionBundle(dit=model, scheduler=scheduler, load_info=load_info)

    @staticmethod
    def _build_dit_config(config: DiffusionConfig) -> NextDiTCrossAttnConfig:
        """Translate the user-facing config into the model's own config type."""
        return NextDiTCrossAttnConfig(
            input_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            dim=config.dim,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            latent_embedding_size=config.latent_embedding_size,
        )

    @staticmethod
    def _load_weights(
        model: NextDiTCrossAttn,
        weights_path: Optional[str],
        *,
        strict: bool,
    ) -> dict:
        """Load a safetensors checkpoint into ``model`` and report key mismatches."""
        if not weights_path:
            return {"missing_keys": [], "unexpected_keys": []}

        path = Path(weights_path)
        if not path.is_file():
            raise FileNotFoundError(f"DiT weights file not found: {path}")

        state_dict = load_file(path)
        result = model.load_state_dict(state_dict, strict=strict)
        # Normalize to plain lists so both code paths return the same types.
        return {
            "missing_keys": list(result.missing_keys),
            "unexpected_keys": list(result.unexpected_keys),
        }

    @staticmethod
    def _load_scheduler(
        scheduler_path: Optional[str],
    ) -> FlowMatchEulerDiscreteScheduler:
        """Load the scheduler from ``scheduler_path`` or the default Hub repo."""
        if scheduler_path:
            return FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_path)
        return FlowMatchEulerDiscreteScheduler.from_pretrained(
            DEFAULT_SCHEDULER_REPO, subfolder=DEFAULT_SCHEDULER_SUBFOLDER
        )


__all__ = ["DiffusionConfig", "DiffusionBundle", "AutoDiffusionModel"]