---
# Global Seed for reproducibility (matches script's set_seed)
seed_everything: 42
# ------------------------------------------------------------------
# Trainer Configuration
# ------------------------------------------------------------------
# Every option below must be nested under `trainer:`; at column 0 the
# options would be read as unrelated top-level keys.
trainer:
  accelerator: auto  # Handles "cuda" if available, else "cpu"
  strategy: auto
  devices: 1
  max_epochs: 100  # Matches args.epochs
  default_root_dir: ./checkpoints
  # Callbacks to replicate the script's checkpointing and logging
  callbacks:
    # Keeps the single best checkpoint (lowest val/loss) plus the last one.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_top_k: 1
        filename: "best_model"
        save_last: true  # Saves 'last.ckpt' (similar to 'final_model.pth')
    # Logs the learning rate once per epoch.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch
# ------------------------------------------------------------------
# Model Configuration (TerraMind + UperNet)
# ------------------------------------------------------------------
# `class_path` names the task class; everything under `init_args` is
# passed to it as constructor arguments.
model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_factory: EncoderDecoderFactory
    loss: ce
    ignore_index: -1
    lr: 1.0e-5
    # Optimizer settings matching _init_optimizer
    optimizer: AdamW
    optimizer_hparams:
      weight_decay: 0.05
    # Scheduler settings matching ReduceLROnPlateau
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      mode: min
      patience: 5
    # --------------------------------------------------------------
    # Model Architecture (Exact match to script's model_config)
    # --------------------------------------------------------------
    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      backbone_merge_method: mean
      decoder: UperNetDecoder
      decoder_scale_modules: true
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.3
      # Specific neck configuration for TerraMind
      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        # NOTE(review): presumably these are the encoder layer indices fed
        # to the decoder; confirm against the backbone depth.
        - name: SelectIndices
          indices: [2, 5, 8, 11]
# ------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------
# `class_path` names the project-local DataModule; everything under
# `init_args` is passed to it as constructor arguments.
data:
  class_path: methane_classification_datamodule.MethaneClassificationDataModule
  init_args:
    # NOTE(review): relative paths; presumably resolved against the
    # directory the training command is launched from. Confirm.
    data_root: ../../MBD_nan_S2_zscore/MBD_nan_S2_zscore
    excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
    batch_size: 8
    val_split: 0.2
    seed: 42  # Same value as seed_everything
    # Note: The procedural train_test_split logic from the script
    # (handling folds/splitting) should be encapsulated inside the
    # DataModule's setup() method for this config to work seamlessly.