|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
|
|
|
from cosmos_predict1.diffusion.networks.general_dit import GeneralDIT |
|
|
from cosmos_predict1.diffusion.networks.general_dit_multiview import MultiviewGeneralDIT |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
|
|
|
# Base configuration for the single-view GeneralDIT backbone.
# L(...) (LazyCall) records GeneralDIT as the target together with the keyword
# arguments below and yields a LazyDict; presumably the network is only built
# when this config is instantiated later.
FADITV2Config: LazyDict = L(GeneralDIT)(
    # Upper bounds on the input extent along height / width / time.
    # NOTE(review): presumably measured in latent units rather than pixels — confirm.
    max_img_h=240,
    max_img_w=240,
    max_frames=128,
    # Channel counts entering and leaving the network.
    # NOTE(review): assumed to be latent-space channels — confirm against the tokenizer.
    in_channels=16,
    out_channels=16,
    # Patch size in space (applied to height and width) and in time.
    patch_spatial=2,
    patch_temporal=1,
    # Transformer width; with num_heads=32 this is 4096 / 32 = 128 channels per head.
    model_channels=4096,
    # NOTE(review): presumably full-attention, cross-attention, MLP sub-blocks — confirm in GeneralDIT.
    block_config="FA-CA-MLP",
    num_blocks=28,
    num_heads=32,
    concat_padding_mask=True,
    # Positional embedding: "rope3d", not learnable, "crop" interpolation.
    pos_emb_cls="rope3d",
    pos_emb_learnable=False,
    pos_emb_interpolation="crop",
    # NOTE(review): presumably the axis order of activations inside blocks — confirm.
    block_x_format="THWBD",
    # NOTE(review): "affline" spelling presumably mirrors GeneralDIT's keyword name;
    # renaming it here alone would break construction.
    affline_emb_norm=True,
    # AdaLN-LoRA enabled with dimension 256.
    use_adaln_lora=True,
    adaln_lora_dim=256,
)
|
|
|
|
|
|
|
|
# Larger variant of FADITV2Config. It is deep-copied so the overrides below
# never leak back into the base config object. Only depth, head count and width
# change; every other setting is inherited unchanged.
# NOTE(review): the "14B" label presumably refers to parameter count — confirm.
FADITV2_14B_Config: LazyDict = copy.deepcopy(FADITV2Config)
FADITV2_14B_Config.num_blocks = 36  # base config uses 28
FADITV2_14B_Config.num_heads = 40  # base config uses 32
# 5120 / 40 = 128 channels per head, the same ratio as the base (4096 / 32).
FADITV2_14B_Config.model_channels = 5120
|
|
|
|
|
|
|
|
# Configuration for the multiview backbone, MultiviewGeneralDIT.
# L(...) (LazyCall) records MultiviewGeneralDIT as the target with these keyword
# arguments and yields a LazyDict; presumably the network is built on instantiation.
# The first block of arguments repeats FADITV2Config's values verbatim; keep the
# two in sync when changing shared hyperparameters.
FADITV2_Multiview_Config: LazyDict = L(MultiviewGeneralDIT)(
    # --- Settings identical to FADITV2Config ---
    # NOTE(review): extents presumably in latent units — confirm.
    max_img_h=240,
    max_img_w=240,
    max_frames=128,
    in_channels=16,
    out_channels=16,
    patch_spatial=2,
    patch_temporal=1,
    # 4096 / 32 = 128 channels per head.
    model_channels=4096,
    block_config="FA-CA-MLP",
    num_blocks=28,
    num_heads=32,
    concat_padding_mask=True,
    pos_emb_cls="rope3d",
    pos_emb_learnable=False,
    pos_emb_interpolation="crop",
    block_x_format="THWBD",
    # NOTE(review): "affline" spelling presumably mirrors the constructor keyword; do not rename here alone.
    affline_emb_norm=True,
    use_adaln_lora=True,
    adaln_lora_dim=256,
    # --- Multiview-specific settings ---
    # NOTE(review): presumably the number of camera views and the width of the
    # per-view conditioning vector — confirm in MultiviewGeneralDIT.
    n_views=6,
    view_condition_dim=6,
    add_repeat_frame_embedding=True,
    # RoPE extrapolation ratios per axis; 1.0 on every axis is presumably "no scaling" — confirm.
    rope_h_extrapolation_ratio=1.0,
    rope_w_extrapolation_ratio=1.0,
    rope_t_extrapolation_ratio=1.0,
)
|
|
|