---
# Provenance (scraped repo-page header, preserved as comments):
# author: Lanni-ni — commit 0fe45e5 (verified): "add remote code + model files"
model:
  _target_: forgetting_transformer.model.forgetting_transformer.modeling_forgetting_transformer.ForgettingTransformerForCausalLM
  config:
    _target_: forgetting_transformer.model.forgetting_transformer.configuration_forgetting_transformer.ForgettingTransformerConfig
    # "???" is Hydra's mandatory-value marker — must be supplied at runtime
    vocab_size: ???
    hidden_size: 256
    hidden_ratio: 4.0
    intermediate_size: null
    num_hidden_layers: 3
    num_heads: 4
    num_kv_heads: null
    hidden_act: swish
    window_size: null
    max_position_embeddings: null
    initializer_range: 0.02
    elementwise_affine: true
    norm_eps: 1.0e-06
    use_cache: true
    pad_token_id: null
    bos_token_id: null
    eos_token_id: null
    tie_word_embeddings: false
    attention_bias: false
    fuse_norm: true
    fuse_cross_entropy: true
    rope_base: 500000.0
    use_rope: false
    use_output_gate: false
    ogate_act: sigmoid
    fgate_type: full
    fgate_bias_init: false
    decay_time_min: null
    decay_time_max: null
    use_output_norm: false
    qk_norm: false
    qk_norm_share_param_across_head: false
    use_k_shift: false
    use_v_shift: false
optimizer:
  _target_: torch.optim.AdamW
  lr: 0.001
  betas:
    - 0.9
    - 0.95
  weight_decay: 0.1
schedule:
  _target_: forgetting_transformer.schedule.warmup_cosine_decay_schedule
  init_value: 0.0
  peak_value: 0.001
  warmup_steps: 20971520
  decay_steps: 2097152000
  end_value: 0.0
datamodule:
  _target_: forgetting_transformer.datamodule.npy.NpyDataModule
  data_path: /workspace/forgetting-transformer/data
  # "???" is Hydra's mandatory-value marker — must be supplied at runtime
  rank: ???
  world_size: ???
  train_batch_len: 2048
  train_batch_size: 1024
  train_num_workers: 0
  eval_tokens: 2147483648
  eval_batch_len: 2048
  eval_local_batch_size: 1
  eval_num_workers: 0
strategy:
  _target_: lightning.fabric.strategies.SingleDeviceStrategy
  device: cuda:0
exp: forgetting_gate_3_4_256
tag: forgetting_gate_3_4_256
seed: 42
hf_load_dir: null
hf_save_dir: null
hf_load_step: null
output_dir: /workspace/forgetting-transformer/forgetting_gate_3_4_256
data_dir: /workspace/forgetting-transformer/data
resume: false
fork_dir: null
fork_step: null
log_interval: 20971520
eval_interval: 41943040
final_eval: true
skip_eval: false
checkpoint_interval: 209715200
train_eval_interval: 104857600
checkpoint_keep_interval: 209715200
fabric:
  devices: 1
  precision: 16-mixed
train:
  max_tokens: 2097152000
  grad_acc_tokens: 32768
  max_grad_norm: 1.0
  gradient_checkpointing: true
  bias_weight_decay: false
  normalization_weight_decay: false
  conv_weight_decay: true
eval:
  min_val_length: 512
wandb:
  project: forgetting-transformer
  mode: online
  log_dir: ./output/wandb