# deepmirt / config.yaml
# Uploaded by liuliu2333 with huggingface_hub (commit e72be62, verified)
# ============================================================
# miRNA-target prediction model — default training configuration
# ============================================================
# This file controls all hyperparameters for the training pipeline.
# Modify this file to tune parameters without changing the code.
#
# Usage:
# python insect_mirna_target/training/train.py --config insect_mirna_target/configs/default.yaml
# python insect_mirna_target/training/train.py --config insect_mirna_target/configs/default.yaml --fast-dev-run
# ============================================================
# ---- Model architecture ----
model:
  freeze_backbone: true  # Whether to freeze the RNA-FM backbone (Phase 1 frozen, Phase 2 unfrozen)
  cross_attn_heads: 8  # Number of Cross-Attention heads (8 heads x 80 dim = 640)
  cross_attn_layers: 2  # Number of stacked Cross-Attention layers (Mimosa uses 16, our lightweight version uses 2)
  classifier_hidden:  # MLP classifier head hidden layer dimensions
    - 256  # First layer: 640->256
    - 64  # Second layer: 256->64, final output 64->1
  dropout: 0.3  # Dropout rate (0.2-0.5 range, 0.3 is a common default)
# ---- Data loading ----
data:
  data_dir: "insect_mirna_target/data/training"  # Training data directory (contains train.csv, val.csv, test.csv)
  batch_size: 128  # Batch size (128-256 recommended, reduce to 64 if GPU memory is insufficient)
  num_workers: 8  # DataLoader worker processes (typically set to 1/4 of CPU cores)
  pin_memory: true  # Pinned memory (recommended true for GPU training, accelerates CPU->GPU data transfer)
# ---- Training hyperparameters ----
training:
  # NOTE: keep the decimal point in exponent floats (1.0e-4, not 1e-4) —
  # PyYAML's YAML 1.1 resolver would read "1e-4" as a string, not a float.
  lr: 1.0e-4  # Base learning rate (used for classifier head; backbone is multiplied by 0.01)
  weight_decay: 1.0e-5  # L2 regularization coefficient (prevents overfitting, 1e-5 to 1e-4 range)
  scheduler: "cosine"  # Learning rate scheduler: cosine | onecycle | none
  max_epochs: 30  # Maximum training epochs (typically 20-50, combined with early stopping)
  gradient_clip_val: 1.0  # Gradient clipping threshold (prevents gradient explosion, 1.0 is standard)
  accumulate_grad_batches: 1  # Gradient accumulation steps (set to 2-4 to simulate larger batch if GPU memory is insufficient)
  precision: "16-mixed"  # Mixed precision training (reduces GPU memory, speeds up computation, with negligible precision loss)
# ---- Progressive unfreezing ----
unfreezing:
  enabled: false  # Whether to enable progressive unfreezing (Phase 1 set false, Phase 2 set true)
  unfreeze_at_epoch: 5  # Epoch at which to unfreeze the first layer (typically after classifier head converges)
  num_layers: 3  # Total number of top RNA-FM layers to unfreeze (3 layers = layers 10, 11, 12)
  unfreeze_interval: 3  # Unfreeze the next layer every N epochs (epoch 5->L12, 8->L11, 11->L10)
  warmup_epochs: 1  # Number of backbone lr warmup epochs after each unfreeze (lr linearly ramps up from 1/10)
# ---- Trainer configuration ----
trainer:
  accelerator: "gpu"  # Accelerator type: gpu | cpu
  devices: 2  # Number of GPUs (this machine has 2x NVIDIA L20)
  strategy: "ddp"  # Distributed strategy: ddp (data parallel) | fsdp (fully sharded data parallel)
# ---- Checkpoint management ----
checkpointing:
  monitor: "val_auroc"  # Metric to monitor (criterion for selecting the best model)
  mode: "max"  # max=higher is better (auroc), min=lower is better (loss)
  save_top_k: 3  # Save top-K best models (disk space limited, keep only 3)
  save_last: true  # Whether to additionally save the last epoch's checkpoint
  dirpath: "checkpoints/"  # Checkpoint save directory
# ---- Logging configuration ----
logging:
  logger: "tensorboard"  # Logging backend: tensorboard (no login required, view locally)
  log_dir: "lightning_logs/"  # TensorBoard log directory
  log_every_n_steps: 50  # Log every N steps (too frequent will slow down training)
# ---- Early stopping ----
early_stopping:
  enabled: true  # Whether to enable early stopping (prevents overfitting, recommended)
  monitor: "val_loss"  # Metric to monitor
  patience: 10  # Stop after N consecutive epochs without improvement (progressive unfreezing needs more recovery time)
  mode: "min"  # min=lower is better
# ---- Random seed ----
seed: 42 # Global random seed (ensures experiment reproducibility)