# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from lcm.models.two_tower_diffusion_lcm.builder import (
    DenoiserConfig,
    EncoderFrontendConfig,
    TransformerConfig,
    TwoTowerDiffusionLCModelConfig,
    lcm_arch,
)
from lcm.nn.projection import ProjectionConfig
from lcm.nn.schedulers import DDIMSchedulerConfig


@lcm_arch("toy_two_tower_diffusion_lcm")
def toy_lcm() -> TwoTowerDiffusionLCModelConfig:
    return TwoTowerDiffusionLCModelConfig(
        context_encoder=TransformerConfig(num_layers=2),
        denoiser=DenoiserConfig(num_layers=2),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_A",
    )


@lcm_arch("arch_lexa_lcm_pre0_toy")
def lexa_lcm_pre0_toy() -> TwoTowerDiffusionLCModelConfig:
    return TwoTowerDiffusionLCModelConfig(
        context_encoder=TransformerConfig(num_layers=2),
        denoiser=DenoiserConfig(num_layers=2),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
    )
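
# Note on `trained_with_cf_guidance` (illustrative, not defined in this file):
# classifier-free guidance trains the denoiser with the conditioning context
# randomly dropped, so that at sampling time the conditional and unconditional
# predictions can be blended, in the standard formulation:
#
#     x_guided = x_uncond + guidance_scale * (x_cond - x_uncond)
#
# The exact blending logic is assumed to live in the sampler elsewhere in the
# codebase.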


@lcm_arch("arch_lexa_lcm_pre0_minimal")
def lexa_lcm_pre0_minimal() -> TwoTowerDiffusionLCModelConfig:
    """3-layer encoder / 6-layer denoiser / model dim 768"""
    model_dim: int = 768  # Reduced from 2048 to 768
    num_attn_heads: int = 12  # Reduced from 16 to 12
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=2048,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=3,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=6,  # Reduced from 13 to 6
            timestep_embed_dim=model_dim,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )
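
# `num_diffusion_train_steps=100` above sets the length of the training noise
# schedule. As a general property of DDIM (not something this file configures),
# sampling can later run with fewer inference steps than training steps by
# striding that schedule.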


@lcm_arch("arch_lexa_lcm_pre0")
def lexa_lcm_pre0() -> TwoTowerDiffusionLCModelConfig:
    """4-layer encoder / 10-layer denoiser / model dim 1024
    Parameter Size: 287,880,192"""
    model_dim: int = 1024  # Reduced from 2048 to 1024
    num_attn_heads: int = 16
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=2048,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=4,  # Reduced from 5 to 4
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=10,  # Reduced from 13 to 10
            timestep_embed_dim=model_dim,
            ffn_inner_dim=3 * model_dim,  # Reduced from 4 * model_dim to 3 * model_dim
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        sonar_normalizer_name="sonar_normalizer_wikipedia_en_1m",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )
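
# Rough parameter-accounting sketch (added for illustration; the helper below is
# hypothetical and not part of the model code). It lower-bounds the
# self-attention + SwiGLU FFN weights across both towers, ignoring
# cross-attention, timestep embeddings, adaptive-norm/conditioning parameters,
# the frontend, pre/post-denoiser projections, biases, and norm weights, so it
# undercounts the exact totals quoted in the docstrings.
def _approx_tower_params(num_layers: int, model_dim: int, ffn_inner_dim: int) -> int:
    attn = 4 * model_dim * model_dim  # q/k/v/output projection matrices
    ffn = 3 * model_dim * ffn_inner_dim  # SwiGLU: gate, up, and down matrices
    return num_layers * (attn + ffn)

# E.g. for `arch_lexa_lcm_pre0`: _approx_tower_params(4 + 10, 1024, 3 * 1024)
# gives ~190.8M of the quoted 287,880,192; the remainder plausibly comes from
# the denoiser's conditioning machinery and the other components ignored above.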


@lcm_arch("two_tower_diffusion_lcm_1_6B")
def two_tower_diffusion_lcm_1_6B() -> TwoTowerDiffusionLCModelConfig:
    """5-layer encoder / 13-layer denoiser / model dim 2048
    Parameter Size: 1,635,101,696"""
    model_dim: int = 2048
    num_attn_heads: int = 16
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=4096,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=5,
            ffn_inner_dim=4 * model_dim,
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=13,
            timestep_embed_dim=model_dim,
            ffn_inner_dim=4 * model_dim,
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_B",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )


@lcm_arch("two_tower_diffusion_lcm_7B")
def two_tower_diffusion_lcm_7B() -> TwoTowerDiffusionLCModelConfig:
    """5-layer encoder / 14-layer denoiser / model dim 4096
    Parameter Size: 6,930,781,696"""
    model_dim: int = 4096
    num_attn_heads: int = 32
    return TwoTowerDiffusionLCModelConfig(
        model_dim=model_dim,
        max_seq_len=4096,
        frontend=EncoderFrontendConfig(),
        context_encoder=TransformerConfig(
            num_layers=5,
            ffn_inner_dim=4 * model_dim,
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pos_embedding_style="rope",
        ),
        denoiser=DenoiserConfig(
            num_layers=14,
            timestep_embed_dim=model_dim,
            ffn_inner_dim=4 * model_dim,
            pos_embedding_style="none",
            num_attn_heads=num_attn_heads,
            final_dropout_p=0.0,
            attention_dropout_p=0.0,
            dropout_p=0.1,
            mha_output_proj_bias=True,
            use_swiglu=True,
            layer_normalization_style="rms",
            pre_denoiser=ProjectionConfig(),
            post_denoiser=ProjectionConfig(),
        ),
        # TODO change normalizer name to align with the normalizer instructions
        sonar_normalizer_name="dummy_sonar_normalizer_C",
        trained_with_cf_guidance=True,
        noise_scheduler=DDIMSchedulerConfig(num_diffusion_train_steps=100),
    )
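
# Usage sketch (assumptions flagged): `lcm_arch` is taken to be a standard
# registry decorator that records each factory under its architecture name and
# returns the function unchanged, so the factories can also be called directly.
if __name__ == "__main__":
    config = lexa_lcm_pre0_minimal()
    # Expected (if the config classes are plain dataclasses, as assumed): 768 6
    print(config.model_dim, config.denoiser.num_layers)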