# BLIP3o-4B-v3-TEST / diffusion_auto.py
# orrzohar's picture
# Upload diffusion_auto.py with huggingface_hub
# 135c97e verified
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from safetensors.torch import load_file
from .nextdit_crossattn import NextDiTCrossAttn, NextDiTCrossAttnConfig
@dataclass
class DiffusionConfig:
    """Settings consumed by ``AutoDiffusionModel.from_config``.

    The two path fields select where weights and the scheduler are loaded
    from; when left as ``None`` no weights are loaded and a built-in default
    scheduler source is used. Every remaining field is forwarded unchanged to
    ``NextDiTCrossAttnConfig`` under the same name, so its precise meaning is
    defined by that class.
    """

    # --- Checkpoint / scheduler locations --------------------------------
    # Handed to ``safetensors.torch.load_file`` when truthy.
    weights_path: Optional[str] = None
    # Handed to ``FlowMatchEulerDiscreteScheduler.from_pretrained`` when truthy.
    scheduler_path: Optional[str] = None

    # --- Values forwarded to ``NextDiTCrossAttnConfig`` -------------------
    # NOTE(review): names suggest transformer width / depth / attention-head
    # counts -- confirm against ``NextDiTCrossAttnConfig``.
    dim: int = 1792
    n_layers: int = 24
    n_heads: int = 28
    n_kv_heads: int = 28
    latent_embedding_size: int = 3584

    # NOTE(review): names suggest the latent grid geometry and channel
    # count -- confirm against ``NextDiTCrossAttnConfig``.
    input_size: int = 8
    patch_size: int = 1
    in_channels: int = 1792
@dataclass
class DiffusionBundle:
    """Result container returned by ``AutoDiffusionModel.from_config``.

    Groups the instantiated model, its scheduler, and a report describing
    how the checkpoint keys lined up with the model's parameters.
    """

    # The instantiated NextDiT cross-attention model.
    dit: NextDiTCrossAttn
    # Scheduler instance loaded via ``from_pretrained``.
    scheduler: FlowMatchEulerDiscreteScheduler
    # Mapping with "missing_keys" and "unexpected_keys" entries, as filled in
    # by ``AutoDiffusionModel.from_config``.
    load_info: dict
class AutoDiffusionModel:
    """Utility to materialize a NextDiT cross-attention model and scheduler."""

    # Fallback scheduler source used when ``DiffusionConfig.scheduler_path``
    # is unset. Exposed as class attributes so a subclass can redirect the
    # default without overriding ``from_config``.
    DEFAULT_SCHEDULER_REPO = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
    DEFAULT_SCHEDULER_SUBFOLDER = "scheduler"

    @classmethod
    def from_config(
        cls,
        config: DiffusionConfig,
        *,
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ) -> DiffusionBundle:
        """Build the model and scheduler described by *config*.

        Steps, in order:

        1. Create a ``NextDiTCrossAttnConfig`` from the architecture fields of
           *config* and instantiate ``NextDiTCrossAttn`` from it.
        2. If ``config.weights_path`` is truthy, read it with
           ``safetensors.torch.load_file`` and load it into the model with
           ``strict=False``; key mismatches are reported, not raised.
        3. Cast and/or move the model when ``torch_dtype`` / ``device`` are
           given.
        4. Load the scheduler from ``config.scheduler_path`` when truthy,
           otherwise from ``DEFAULT_SCHEDULER_REPO`` /
           ``DEFAULT_SCHEDULER_SUBFOLDER``.

        Args:
            config: Architecture parameters plus optional weight and
                scheduler locations.
            device: Target device for the model; left where it was created
                when ``None``.
            torch_dtype: Target dtype for the model; left unchanged when
                ``None``.

        Returns:
            A ``DiffusionBundle`` holding the model, the scheduler, and a
            ``load_info`` dict whose ``"missing_keys"`` and
            ``"unexpected_keys"`` entries are always lists (empty when no
            weights were loaded).

        Note:
            Errors raised by ``load_file``, ``load_state_dict`` or
            ``from_pretrained`` are not caught here and propagate to the
            caller.
        """
        dit_config = NextDiTCrossAttnConfig(
            input_size=config.input_size,
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            dim=config.dim,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            latent_embedding_size=config.latent_embedding_size,
        )
        model = NextDiTCrossAttn(dit_config)

        # Use lists in both the "nothing loaded" and "loaded" cases so callers
        # always see the same container type for these entries.
        load_info = {"missing_keys": [], "unexpected_keys": []}
        if config.weights_path:
            state_dict = load_file(Path(config.weights_path))
            # strict=False: tolerate partial checkpoints and surface the
            # mismatches through ``load_info`` instead of raising.
            load_result = model.load_state_dict(state_dict, strict=False)
            load_info = {
                "missing_keys": list(load_result.missing_keys),
                "unexpected_keys": list(load_result.unexpected_keys),
            }

        # Apply dtype and device in a single ``.to()`` call so the module's
        # tensors are traversed once rather than twice.
        to_kwargs = {}
        if torch_dtype is not None:
            to_kwargs["dtype"] = torch_dtype
        if device is not None:
            to_kwargs["device"] = device
        if to_kwargs:
            model = model.to(**to_kwargs)

        if config.scheduler_path:
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                config.scheduler_path
            )
        else:
            # NOTE(review): the default looks like a Hugging Face Hub repo id,
            # so this branch presumably needs network access or a populated
            # local cache -- confirm.
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                cls.DEFAULT_SCHEDULER_REPO,
                subfolder=cls.DEFAULT_SCHEDULER_SUBFOLDER,
            )

        return DiffusionBundle(dit=model, scheduler=scheduler, load_info=load_info)
# Names exported by ``from <this module> import *``.
__all__ = [
    "DiffusionConfig",
    "DiffusionBundle",
    "AutoDiffusionModel",
]