# train.toml — LoRA training configuration
# Source: "testinggg" repo, uploaded by ApacheOne
# ("Upload LoRA adapter + config + train.toml", revision a859e05, verified)
# Output path for training runs. Each run creates a new subdirectory here.
output_dir = "/content/outputs"

# Dataset configuration file consumed by the trainer.
dataset = "/content/run_cfg/dataset.toml"

# --- Training schedule ---
epochs = 1
micro_batch_size_per_gpu = 1
pipeline_stages = 1
gradient_accumulation_steps = 1
gradient_clipping = 1.0
warmup_steps = 0

# --- Evaluation ---
eval_every_n_epochs = 1
eval_before_first_step = true
eval_micro_batch_size_per_gpu = 1
eval_gradient_accumulation_steps = 1

# --- Checkpointing and miscellaneous ---
save_every_n_epochs = 1
checkpoint_every_n_minutes = 120
activation_checkpointing = true
partition_method = "parameters"
save_dtype = "bfloat16"
caching_batch_size = 1
steps_per_print = 1
# Flux2-Klein model, using the nested config structure.
[model]
dtype = "bfloat16"
type = "flux_2_klein"

[model.paths]
diffusers = "/content/models/flux2_klein_base_9b"

[model.vae]
latent_mode = "sample"

[model.lora]
format = "ai_toolkit_peft"

# Diffusers-parity pipeline: diffusers-style timestep sampling + SD3 weighting.
[model.train]
# Explicitly select the diffusers-parity pipeline.
pipeline = "diffusers"

# Knobs mirroring the official diffusers trainer
# (external/diffusers/examples/dreambooth/train_dreambooth_lora_flux2_klein.py,
# which uses compute_density_for_timestep_sampling +
# compute_loss_weighting_for_sd3).
[model.train.diffusers]
# Enable diffusers-parity even if `pipeline` is not explicitly set.
enabled = true
# One of: none | sigma_sqrt | logit_normal | mode | cosmap
weighting_scheme = "none"
logit_mean = 0.0
logit_std = 1.0
mode_scale = 1.29
# Typically 1000 in diffusers schedulers.
num_train_timesteps = 1000
# The rest of the training config stays the same as native.
[adapter]
type = "lora"
rank = 32
dtype = "bfloat16"
# AdamW via the optimi implementation.
[optimizer]
type = "adamw_optimi"
lr = 2e-5
betas = [0.9, 0.99]
weight_decay = 0.01
eps = 1e-8
# Weights & Biases logging (disabled; fill in the keys below to enable).
[monitoring]
enable_wandb = false
wandb_api_key = ""
wandb_tracker_name = ""
wandb_run_name = ""