File size: 2,116 Bytes
fb45cfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Model configuration: transport, unified DCM loss, network, and pretrained encoders.
model:
  transport:
    target: tim.schedulers.transports.OT_FM
    params:
      P_mean: 0.0
      P_std: 1.6
      sigma_d: 1.0
  unified_dcm_loss:
    diffusion_ratio: 0.5
    consistency_ratio: 0.1
    derivative_type: dde
    differential_epsilon: 0.005
    weight_time_type: sqrt
    weight_time_tangent: true
  network:
    target: tim.models.t2i.tim_model.TiM
    params:
      input_size: 16
      patch_size: 1
      # NOTE(review): presumably must equal the latent channels of the VAE
      # named in vae_dir (c32) — confirm.
      in_channels: 32
      depth: 28
      hidden_size: 1152
      # NOTE(review): presumably must equal the hidden size of the text encoder
      # named in text_encoder_dir — confirm.
      cap_feat_dim: 1152
      num_heads: 16
      encoder_depth: 8
      qk_norm: true
      z_dim: 768
      new_condition: t-r
      use_new_embed: true
      distance_aware: true
      lora_hidden_size: 384
  # pretrained VAE
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  # text encoder
  text_encoder_dir: google/gemma-3-1b-it
  # NOTE(review): training.proportion_empty_prompts is also set (to 0.0);
  # confirm which of the two values the trainer actually reads.
  proportion_empty_prompts: 0.1
  use_last_hidden_state: true
  max_seq_length: 256
  # repa encoder
  enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: true
  ema_decay: 0.9999

# Dataset locations and dataloader settings.
data:
  data_type: image_ms
  dataset:
    root_dir: datasets/t2i_toy_dataset
    packed_json: datasets/t2i_toy_dataset/bucket_sampler.json
    # Despite the _dir suffix, this points at a single .jsonl file.
    jsonl_dir: datasets/t2i_toy_dataset/data_info.jsonl
  dataloader:
    num_workers: 4
    batch_size: 128  # Batch size (per device) for the training dataloader.

# Optimization, scheduling, checkpointing, and precision settings.
training:
  tracker: null
  max_train_steps: 500000
  checkpointing_steps: 1000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  # NOTE(review): presumably the batch size learning_rate is tuned for, used
  # when scale_lr is true — confirm the scaling rule in the trainer.
  learning_rate_base_batch_size: 512
  scale_lr: true
  # One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
  # "constant", "constant_with_warmup".
  lr_scheduler: constant
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer:
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  # NOTE(review): model.proportion_empty_prompts is also set (to 0.1);
  # confirm which of the two values the trainer actually reads.
  proportion_empty_prompts: 0.0
  # One of "no", "fp16", "bf16". Write "no" quoted: YAML 1.1 loaders such as
  # PyYAML parse a bare no as boolean false.
  mixed_precision: bf16
  allow_tf32: true
  validation_steps: 500
  checkpoint_list: [100000, 200000, 300000, 400000]