flickr8k-backend / core /config.py
Rohan3's picture
Updated: VAE, UNet, config, text embeddings, model and main
a625e96
# Colour names / hex codes. Judging by the name these are meant for tqdm
# progress bars; the usage itself is not visible in this file.
tqdm_colors = [
    "#0867C5",
    "green",
    "blue",
    "red",
    "yellow",
    "#6707B0",
    "#FFDD22",
    "#00FFEF",
    "#DC442C",
]

# --- Dataset locations -------------------------------------------------------
# image_res is embedded in every resolution-specific path and file name below.
image_res = 256
data_dir = "./backend/core/data"
image_dir = data_dir + "/Images"
resized_img_dir = f"{data_dir}/ResizedImages_{image_res}"

# --- VAE training hyperparameters --------------------------------------------
vae_num_epochs = 1000
vae_stopping_patience = 30
vae_batch_size = 4
vae_group_size = 4
vae_optim_lr = 5e-5
vae_dropout = 0.0

# Latent layout; the recon-image directory name below is built as
# {channels}_{dim}_{dim}.
vae_latent_channels = 4  # alternative values noted by the author: 16, 8
vae_latent_dim = 32

# Loss-term weights (names suggest KL-divergence, total-variation and LPIPS
# terms -- confirm against the training loop).
vae_beta_kld = 1e-3
vae_lambda_tvl = 0  # alternative value noted by the author: 1e-3
vae_lpips_weight = 1e-1

# --- VAE checkpoint ----------------------------------------------------------
vae_checkpoint_dir = "./backend/core/checkpoints/vae"
# The checkpoint file name encodes the hyperparameters above, so editing any
# of them changes the resulting name. Floats are rendered with their default
# formatting (1e-3 -> "0.001", 5e-5 -> "5e-05").
# Earlier naming scheme, kept for reference:
# vae_weight = f"attn_test_best_{image_res}_3_32_128_256_512_{vae_latent_channels}_beta_{vae_beta_kld}_tvl_{vae_lambda_tvl}_batch_{vae_batch_size}_lr_{vae_optim_lr}.pth"
vae_weight = (
    f"ssim_lpips_attn_test_best_{image_res}_3_32_128_256_512_{vae_latent_channels}"
    f"_beta_{vae_beta_kld}_tvl_{vae_lambda_tvl}"
    f"_batch_{vae_batch_size}_lr_{vae_optim_lr}.pth"
)

# --- Latent output directories (one set per image resolution) ----------------
latent_dir = f"./backend/core/latents_{image_res}"
latent_scaled_dir = f"./backend/core/latents_scaled_{image_res}"
latent_norm_dir = f"./backend/core/latents_norm_{image_res}"
latent_recon_images = (
    "./backend/core/"
    f"recon_img_{vae_latent_channels}_{vae_latent_dim}_{vae_latent_dim}_res_{image_res}"
)
# latent_mu = 0.021192772313952446; latent_std = 0.9767765402793884 # Latents
# latent_mu = 0.0007418220047838986; latent_std = 0.9822604060173035 # Latents scale = 1.0180599689483643
# Latent statistics currently in use.
# Latents scale = 0.8030286431312561 (numerically 1 / latent_std)
latent_mu = -0.08573896437883377
latent_std = 1.2452856302261353
# Mean over latent: 0.021192772313952446
# STD over latent: 0.9767765402793884
# Latent Scale: 1.023775577545166
# 🔥 Epoch 104: Avg Train Loss=141.103653 | Recon=129.274836 | KLD=11.828816 | TVL=0.000000 | LPIPS=0.000000 | SSIM=0.000000
# Epoch 104/1000: 100%|███████████████████████████████| 203/203 [00:06<00:00, 31.10it/s]
# 🧪 Test Loss = 181.512993
# Saved best model at 104
# Progress: 100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 4.04it/s]
# Scaled latents: Mean over latent: -0.06069774180650711; STD over latent: 0.9999988079071045, Latent Scale: 1.0000011920928955
# Normalized latents: Mean over latent: 6.175390154794513e-08, STD over latent: 1.0000004768371582, Latent Scale: 0.9999995231628418
# For first 16 latent images
# latent_mu = -0.0032552392221987247
# latent_std = 0.9028854966163635
# ////////////////////////////////////////////////
# Path to the captions file inside the dataset directory. Despite the "_dir"
# suffix, the value points at a single .txt file.
text_captions_dir = data_dir + "/captions.txt"
# --- Text-embedding settings -------------------------------------------------
unet_batch_size = 32
embedding_dim = 1024
embedding_dir = f"./backend/core/embeddings_77_{embedding_dim}"
# Derived from embedding_dir so the two paths cannot drift apart. Despite the
# "_dir" suffix, the value points at a single .pt file.
null_embedding_dir = f"{embedding_dir}/null_embedding.pt"
# Model / pretrained-weights identifiers (presumably consumed by the embedding
# loader -- confirm against the code that reads them).
embedding_model = "ViT-g-14"  # or "ViT-B-16"
embedding_pretrained = "laion2b_s12b_b42k"  # or "openai"

# --- UNet / latent-diffusion training ----------------------------------------
unet_pred_type = "v_prediction"  # v_prediction or epsilon
# Plain string: there is nothing to interpolate, so no f-prefix is needed.
unet_checkpoint_dir = "./backend/core/checkpoints/ldm"
unet_max_steps = 256_000
# new_lr = old_lr * (batch_size_new / batch_size_old)
# NOTE(review): the step count quoted below (1103080) is larger than
# unet_max_steps (256_000); one of the two is probably stale -- confirm.
unet_optim_lr = 1e-4  # Changed to 5e-5 after 1103080 steps
unet_group_size = 32
unet_beta_schedule = "squaredcos_cap_v2"  # "linear" or "squaredcos_cap_v2"
unet_dropout = 0.0
attn_dropout = 0.0
unet_train_timesteps = 1000

# --- DDIM sampling -----------------------------------------------------------
# The original (misspelled) name is kept because other modules may import it.
ddim_guidace_scale = 8
# Correctly spelled alias; prefer this name in new code.
ddim_guidance_scale = ddim_guidace_scale
ddim_num_sampling_steps = 100
ddim_img_dir = "./backend/core/ddim_recon_img"
unet_val_embeddings_dir = "./backend/core/embeddings_val"