# ============================================================
# miRNA-target prediction model — default training configuration
# ============================================================
# This file controls all hyperparameters for the training pipeline.
# Modify this file to tune parameters without changing the code.
#
# Usage:
#   python insect_mirna_target/training/train.py --config insect_mirna_target/configs/default.yaml
#   python insect_mirna_target/training/train.py --config insect_mirna_target/configs/default.yaml --fast-dev-run
# ============================================================
---
# ---- Model architecture ----
model:
  freeze_backbone: true  # Whether to freeze the RNA-FM backbone (Phase 1 frozen, Phase 2 unfrozen)
  cross_attn_heads: 8  # Number of Cross-Attention heads (8 heads x 80 dim = 640)
  cross_attn_layers: 2  # Number of stacked Cross-Attention layers (Mimosa uses 16, our lightweight version uses 2)
  classifier_hidden:  # MLP classifier head hidden layer dimensions
    - 256  # First layer: 640->256
    - 64  # Second layer: 256->64, final output 64->1
  dropout: 0.3  # Dropout rate (0.2-0.5 range, 0.3 is a common default)
# ---- Data loading ----
data:
  data_dir: "insect_mirna_target/data/training"  # Training data directory (contains train.csv, val.csv, test.csv)
  batch_size: 128  # Batch size (128-256 recommended, reduce to 64 if GPU memory is insufficient)
  num_workers: 8  # DataLoader worker processes (typically set to 1/4 of CPU cores)
  pin_memory: true  # Pinned memory (recommended true for GPU training, accelerates CPU->GPU data transfer)
# ---- Training hyperparameters ----
training:
  lr: 1.0e-4  # Base learning rate (used for classifier head; backbone is multiplied by 0.01)
  weight_decay: 1.0e-5  # L2 regularization coefficient (prevents overfitting, 1e-5 to 1e-4 range)
  scheduler: "cosine"  # Learning rate scheduler: cosine | onecycle | none
  max_epochs: 30  # Maximum training epochs (typically 20-50, combined with early stopping)
  gradient_clip_val: 1.0  # Gradient clipping threshold (prevents gradient explosion, 1.0 is standard)
  accumulate_grad_batches: 1  # Gradient accumulation steps (set to 2-4 to simulate larger batch if GPU memory is insufficient)
  precision: "16-mixed"  # Mixed precision training (reduces GPU memory, speeds up computation, with negligible precision loss)
# ---- Progressive unfreezing ----
unfreezing:
  enabled: false  # Whether to enable progressive unfreezing (Phase 1 set false, Phase 2 set true)
  unfreeze_at_epoch: 5  # Epoch at which to unfreeze the first layer (typically after classifier head converges)
  num_layers: 3  # Total number of top RNA-FM layers to unfreeze (3 layers = layers 10, 11, 12)
  unfreeze_interval: 3  # Unfreeze the next layer every N epochs (epoch 5->L12, 8->L11, 11->L10)
  warmup_epochs: 1  # Number of backbone lr warmup epochs after each unfreeze (lr linearly ramps up from 1/10)
# ---- Trainer configuration ----
trainer:
  accelerator: "gpu"  # Accelerator type: gpu | cpu
  devices: 2  # Number of GPUs (this machine has 2x NVIDIA L20)
  strategy: "ddp"  # Distributed strategy: ddp (data parallel) | fsdp (fully sharded data parallel)
# ---- Checkpoint management ----
checkpointing:
  monitor: "val_auroc"  # Metric to monitor (criterion for selecting the best model)
  mode: "max"  # max=higher is better (auroc), min=lower is better (loss)
  save_top_k: 3  # Save top-K best models (disk space limited, keep only 3)
  save_last: true  # Whether to additionally save the last epoch's checkpoint
  dirpath: "checkpoints/"  # Checkpoint save directory
# ---- Logging configuration ----
logging:
  logger: "tensorboard"  # Logging backend: tensorboard (no login required, view locally)
  log_dir: "lightning_logs/"  # TensorBoard log directory
  log_every_n_steps: 50  # Log every N steps (too frequent will slow down training)
# ---- Early stopping ----
early_stopping:
  enabled: true  # Whether to enable early stopping (prevents overfitting, recommended)
  monitor: "val_loss"  # Metric to monitor
  patience: 10  # Stop after N consecutive epochs without improvement (progressive unfreezing needs more recovery time)
  mode: "min"  # min=lower is better
# ---- Random seed ----
seed: 42  # Global random seed (ensures experiment reproducibility)