File size: 2,341 Bytes
a4654b2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | from dataclasses import dataclass, field
from pathlib import Path
from infer_runtime.infer_config import InferConfig
def _resolve_root() -> Path:
here = Path(__file__).resolve().parent
if (here / "transformer").exists() and (here / "vae").exists() and (here / "JoyAI-Image-Und").exists():
return here
raise ValueError(
"Place this config file directly inside the checkpoint root."
)
_ROOT = _resolve_root()
@dataclass
class JoyAIImageInferConfig(InferConfig):
dit_arch_config: dict = field(
default_factory=lambda: {
"target": "modules.models.Transformer3DModel",
"params": {
"hidden_size": 4096,
"in_channels": 16,
"heads_num": 32,
"mm_double_blocks_depth": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"rope_dim_list": [16, 56, 56],
"text_states_dim": 4096,
"rope_type": "rope",
"dit_modulation_type": "wanx",
"theta": 10000,
"attn_backend": "flash_attn",
},
}
)
vae_arch_config: dict = field(
default_factory=lambda: {
"target": "modules.models.WanxVAE",
"params": {
"pretrained": str(_ROOT / "vae" / "Wan2.1_VAE.pth"),
},
}
)
text_encoder_arch_config: dict = field(
default_factory=lambda: {
"target": "modules.models.load_text_encoder",
"params": {
"text_encoder_ckpt": str(_ROOT / "JoyAI-Image-Und"),
},
}
)
scheduler_arch_config: dict = field(
default_factory=lambda: {
"target": "modules.models.FlowMatchDiscreteScheduler",
"params": {
"num_train_timesteps": 1000,
"shift": 4.0,
},
}
)
dit_precision: str = "bf16"
vae_precision: str = "bf16"
text_encoder_precision: str = "bf16"
text_token_max_length: int = 2048
# Keep these fields visible in the active config because they control multi-GPU inference.
hsdp_shard_dim: int = 1
reshard_after_forward: bool = False
use_fsdp_inference: bool = False
cpu_offload: bool = False
pin_cpu_memory: bool = False
|