---
# TiM text-to-image (t2i) training configuration.
# NOTE(review): this file arrived as a flattened paste (viewer residue from
# commit fb45cfb removed); the nesting below was reconstructed from key
# semantics and section comments. In particular, `vae_dir` through `ema_decay`
# are assumed to sit directly under `model` — confirm against the config loader.
model:
  # Flow-matching transport / noise scheduler.
  transport:
    target: tim.schedulers.transports.OT_FM
    params:
      P_mean: 0.0
      P_std: 1.6
      sigma_d: 1.0
  # Unified diffusion/consistency loss weighting.
  unified_dcm_loss:
    diffusion_ratio: 0.5
    consistency_ratio: 0.1
    derivative_type: dde
    differential_epsilon: 0.005
    weight_time_type: sqrt
    weight_time_tangent: true
  # Backbone network (DiT-style transformer).
  network:
    target: tim.models.t2i.tim_model.TiM
    params:
      input_size: 16
      patch_size: 1
      in_channels: 32
      depth: 28
      hidden_size: 1152
      cap_feat_dim: 1152
      num_heads: 16
      encoder_depth: 8
      qk_norm: true
      z_dim: 768
      new_condition: t-r
      use_new_embed: true
      distance_aware: true
      lora_hidden_size: 384
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  # text encoder
  text_encoder_dir: google/gemma-3-1b-it
  proportion_empty_prompts: 0.1
  use_last_hidden_state: true
  max_seq_length: 256
  # repa encoder
  enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: true
  ema_decay: 0.9999
data:
  data_type: image_ms
  dataset:
    root_dir: datasets/t2i_toy_dataset
    packed_json: datasets/t2i_toy_dataset/bucket_sampler.json
    jsonl_dir: datasets/t2i_toy_dataset/data_info.jsonl
  dataloader:
    num_workers: 4
    batch_size: 128  # Batch size (per device) for the training dataloader.
training:
  tracker: null
  max_train_steps: 500000
  checkpointing_steps: 1000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 512
  scale_lr: true  # scale LR by (global batch size / learning_rate_base_batch_size)
  lr_scheduler: constant  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer:
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16  # ["no", "fp16", "bf16"]
  allow_tf32: true
  validation_steps: 500
  checkpoint_list: [100000, 200000, 300000, 400000]