# CVD / FTTransformerModel / config.yml
# Last commit 2ee0e24 (verified) by yonghan93:
#   "Update FTTransformerModel/config.yml" (2.29 kB)
---
# Training configuration for the FTTransformer model
# (project_name: CVD_FTTransformer).
# NOTE(review): the key set looks like a merged PyTorch Tabular config
# (data, model, trainer, experiment and optimizer sections flattened into
# one mapping); confirm against the library version that loads it.

# --- Columns and data handling -------------------------------------------
target:
- cardio
continuous_cols:
- age
- height
- weight
- ap_hi
- ap_lo
- bmi
- pulse_pressure
# categorical_dim (near the end of this file) is 7, matching the number of
# entries in this list; keep the two in sync when adding/removing columns.
categorical_cols:
- gender
- cholesterol
- gluc
- smoke
- alco
- active
- hypertension
date_columns: []
encode_date_columns: true
validation_split: 0.2
continuous_feature_transform: null
normalize_continuous_features: true
quantile_noise: 0
num_workers: 0
pin_memory: true
handle_unknown_categories: true
handle_missing_values: true

# --- Task, head, loss and metrics ----------------------------------------
task: classification
head: LinearHead
# head_config is a nested mapping: `layers` must be indented under it.
# Left flush at column 0, `layers` would load as a separate top-level key
# and head_config itself would load as null.
# NOTE(review): an empty `layers` string presumably means "no hidden
# layers" in the linear head -- confirm against the head implementation.
head_config:
  layers: ''
embedding_dims: null
embedding_dropout: 0.1
batch_norm_continuous_input: true
learning_rate: 0.0006
loss: CrossEntropyLoss
# metrics, metrics_prob_input and metrics_params are parallel lists
# (one entry each here, all for `accuracy`).
metrics:
- accuracy
metrics_prob_input:
- false
metrics_params:
- {}
target_range: null
virtual_batch_size: null
seed: 42

# --- FTTransformer model definition --------------------------------------
# These private keys appear to name the module and classes used to build
# the model; they are written by the saving code, not tuned by hand.
_module_src: models.ft_transformer
_model_name: FTTransformerModel
_backbone_name: FTTransformerBackbone
_config_name: FTTransformerConfig
input_embed_dim: 32
embedding_initialization: kaiming_uniform
embedding_bias: true
share_embedding: false
share_embedding_strategy: fraction
shared_embedding_fraction: 0.25
attn_feature_importance: true
num_heads: 8
num_attn_blocks: 3
transformer_head_dim: null
attn_dropout: 0.1
add_norm_dropout: 0.1
ff_dropout: 0.1
ff_hidden_multiplier: 4
transformer_activation: GEGLU

# --- Trainer ---------------------------------------------------------------
batch_size: 512
data_aware_init_batch_size: 2000
fast_dev_run: false
max_epochs: 100
min_epochs: 1
max_time: null
accelerator: auto
devices: -1
devices_list: null
accumulate_grad_batches: 1
auto_lr_find: false
auto_select_gpus: true
check_val_every_n_epoch: 1
gradient_clip_val: 0.0
overfit_batches: 0.0
deterministic: false
profiler: null
# Early stopping and checkpointing both monitor valid_loss in `min` mode.
early_stopping: valid_loss
early_stopping_min_delta: 0.001
early_stopping_mode: min
early_stopping_patience: 10
early_stopping_kwargs: {}
checkpoints: valid_loss
checkpoints_path: saved_models
checkpoints_every_n_epochs: 1
checkpoints_name: null
checkpoints_mode: min
checkpoints_save_top_k: 1
checkpoints_kwargs: {}
load_best: true
track_grad_norm: -1
progress_bar: rich
precision: 32
trainer_kwargs: {}

# --- Experiment tracking -------------------------------------------------
project_name: CVD_FTTransformer
run_name: run_01
exp_watch: null
log_target: tensorboard
log_logits: false
exp_log_freq: 100

# --- Optimizer and LR scheduler ------------------------------------------
optimizer: Adam
optimizer_params: {}
# NOTE(review): lr_scheduler is null, so the two scheduler keys below
# presumably have no effect -- confirm.
lr_scheduler: null
lr_scheduler_params: {}
lr_scheduler_monitor_metric: valid_loss

# --- Derived / runtime values ---------------------------------------------
# Equals the number of entries in categorical_cols above (7).
categorical_dim: 7
enable_checkpointing: true