# GridAR-3B-T2I / config.yaml
# GridTok (2D) AR d36 - Text-to-Image (T2I) inference configuration (no muP)
# Model: 36 transformer blocks, embed_dim=2304, 36 attention heads
# Uses a 2D grid tokenizer (16x16) instead of a 1D one (256x1)
# Uses the no-muP checkpoint (muP scaling is baked into the weights)
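# Note: the `_target_`/`_partial_` fields below follow Hydra's instantiation
# convention. A minimal loading sketch (an assumption on our part; the actual
# entry point lives in the flextok_ar/l3m codebases) might look like:
#   from omegaconf import OmegaConf
#   from hydra.utils import instantiate
#   cfg = OmegaConf.load("config.yaml")
#   meta_model = instantiate(cfg.model.meta_model)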
model:
  modality: image
  # For the 2D grid, the AR checkpoint does NOT contain image_tokenizer
  # weights; the tokenizer is loaded from a separate HuggingFace repo.
  remove_image_tokenizer: true
  meta_model:
    _target_: l3m.model.meta_models.InSeriesMetaModels
    models:
      image_tokenizer:
        _target_: flextok_ar.model.integration.ImageResamplerTokenizer
        # 2D grid tokenizer hosted separately on HuggingFace
        model_id: ZhitongGao/GridAR_256
        force_vae_encode: true
        sample_posterior: true
        image_dims: [256, 256]
        read_key: image
        write_key: image_token_ids
        vae_latent_read_key: vae_latents
        del_decoders: false
        token_grid_size: [16, 16]
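        # A 16x16 grid yields 16 * 16 = 256 image tokens per sample, which is
        # why generation.num_keep_tokens below is set to 256.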
      text_encoder:
        _target_: flextok_ar.model.text_encoder.T5EmbedderWithMLP
        t5_embedder:
          _target_: flextok_ar.model.text_encoder.T5Embedder
          read_key: text
          text_embeddings_write_key: text_embeddings
          text_embeddings_mask_write_key: cross_attn_mask
          hf_hub_path: google/flan-t5-xl
          encoder_seqlen_max: 128
          decoder_seqlen: 256
          cond_dropout_p: 0.1
        mlp:
          _target_: flextok_ar.model.text_encoder.TextToEmbedMLP
          text_dim: 2048
          embed_dim: 2304
          act_layer:
            _target_: torch.nn.GELU
            _partial_: true
            approximate: tanh
          use_bias: false
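        # The MLP projects T5-XL features (text_dim=2048, Flan-T5-XL's hidden
        # size) to the decoder width (embed_dim=2304), so that cross-attention
        # (encoder_dim=2304 below) sees text and image streams in one dimension.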
      ar_image_model:
        _target_: l3m.model.meta_models.MetaModel
        preprocessor:
          _target_: flextok_ar.model.preprocessors.ARImageEmbedPreprocessor
          read_key: image_token_ids
          write_key:
            - input_embeddings
            - target_token_ids
          inference_read_key: pred_image_token_ids
          token_grid_size: [16, 16]
          codebook_size: 64000
          embed_dim: 2304
          pos_embed_type: absolute
          num_classes: null
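        # The preprocessor looks up each of the 64,000 codebook ids in a
        # 2304-dim embedding table and adds absolute position embeddings;
        # num_classes is null here, presumably because conditioning comes
        # from text embeddings rather than class labels.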
        trunk:
          _target_: l3m.model.trunks.transformer_decoder.TransformerDecoder
          read_key: input_embeddings
          write_key: output_embeddings
          encoder_output_key: text_embeddings
          self_attn_mask_read_key: cross_attn_mask
          embed_dim: 2304
          num_blocks: 36
          mlp_ratio: 4
          norm_layer:
            _target_: l3m.model.layers.normalization.LayerNormFP32
            _partial_: true
            eps: 1.0e-05
          ffn_target:
            _target_: l3m.model.layers.ffn.SwiGLUFFN
            _partial_: true
          self_attn_target:
            _target_: l3m.model.layers.attention.EfficientAttention
            _partial_: true
            dim: 2304
            num_heads: 36
            qkv_bias: false
            is_causal: true
            qk_norm:
              _target_: l3m.model.layers.normalization.LayerNormFP32
              _partial_: true
              eps: 1.0e-05
          rope_pos_embed: null
          cross_attn_target:
            _target_: flextok_ar.model.attention.GeneralizedAttentionWithMask
            _partial_: true
            dim: 2304
            encoder_dim: 2304
            num_heads: 36
            qkv_bias: false
            is_causal: false
            qk_norm:
              _target_: l3m.model.layers.normalization.LayerNormFP32
              _partial_: true
              eps: 1.0e-05
          weight_init_style: jax
          post_trunk_norm: true
          use_bias: false
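        # With embed_dim=2304 and 36 heads, the per-head dimension is
        # 2304 / 36 = 64, a common head size for attention kernels.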
        postprocessor:
          _target_: torch.nn.Identity
        head:
          _target_: l3m.model.heads.classifier.LinearClassifier
          read_key: output_embeddings
          write_key: token_preds
          in_features: 2304
          out_features: 64000
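        # out_features matches codebook_size (64,000), so the head produces
        # one logit per image-token codebook entry at every sequence position.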
generation:
  model_type: ar_text_to_image_model
  sample: true
  temperature: 1.0
  top_k: 0
  top_p: 0.0
  cfg_factor: 3.0
  num_keep_tokens: 256
  num_samples: 1
  timesteps: 25
  tokenizer_cfg_factor: 5.0
  tokenizer_perform_norm_guidance: true
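  # A hedged reading of the sampling settings (conventions vary by codebase):
  # top_k: 0 and top_p: 0.0 usually mean those filters are disabled, and
  # cfg_factor is typically applied as classifier-free guidance over the
  # logits, e.g.
  #   logits = uncond + cfg_factor * (cond - uncond)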
# VAE image sizes for 2D grid decode (f8 VAE: 256 / 8 = 32)
decode:
  vae_image_sizes: 32
image:
  mean: [0.5, 0.5, 0.5]
  std: [0.5, 0.5, 0.5]
  size: 256
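  # With mean = std = 0.5 per channel, pixels in [0, 1] are normalized to
  # [-1, 1] via (x - 0.5) / 0.5, the input range most VAE encoders expect.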
device: cuda
seed: 42
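# A hypothetical end-to-end sketch of how the pieces above compose at
# inference (method names are illustrative only, not the actual flextok_ar
# API; the read/write keys come from this config):
#   data = {"text": "a photo of a corgi"}
#   data = text_encoder(data)       # writes text_embeddings + cross_attn_mask
#   data = ar_image_model(data)     # autoregressively writes pred_image_token_ids
#   image = image_tokenizer.decode(data)  # 16x16 token grid -> 256x256 RGB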