File size: 2,849 Bytes
97a17c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Reproducibility: one global RNG seed, mirroring set_seed in the original script
seed_everything: 42  # consumed by LightningCLI before Trainer/model/data are built

# ------------------------------------------------------------------
# Trainer Configuration
# ------------------------------------------------------------------
trainer:
  # "auto" lets Lightning select whichever accelerator is available
  # (e.g. a CUDA GPU) and falls back to CPU when none is found.
  accelerator: auto
  strategy: auto
  devices: 1
  max_epochs: 100  # Matches args.epochs in the original training script
  default_root_dir: ./checkpoints

  # Callbacks replicating the script's checkpointing and LR logging.
  callbacks:
    # Keeps only the single checkpoint with the lowest val/loss, saved as
    # "best_model.ckpt", plus a rolling "last.ckpt" (the counterpart of the
    # script's final_model.pth).
    # NOTE(review): assumes the task logs a metric named exactly "val/loss".
    # No dirpath is set, so Lightning derives the output directory from
    # default_root_dir / the active logger. Set dirpath if a fixed location
    # is required.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_top_k: 1
        filename: "best_model"
        save_last: true

    # Records the learning rate once per epoch.
    # NOTE(review): this callback needs an active logger. Lightning attaches
    # a default one unless `logger` is disabled; confirm if you change that.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

# ------------------------------------------------------------------
# Model Configuration (TerraMind + UperNet)
# ------------------------------------------------------------------
model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_factory: EncoderDecoderFactory
    loss: ce
    # Targets equal to this label are skipped (cross-entropy ignore_index).
    ignore_index: -1
    # Written with a decimal point and a signed exponent on purpose:
    # YAML 1.1 loaders such as PyYAML only resolve that form as a float.
    # A bare "1e-5" would be read as a string.
    lr: 1.0e-5

    # Optimizer settings matching _init_optimizer
    optimizer: AdamW
    optimizer_hparams:
      weight_decay: 0.05

    # Scheduler settings matching ReduceLROnPlateau. mode "min" assumes the
    # monitored quantity is a loss. NOTE(review): the monitored metric is
    # chosen by the task, not in this file; confirm.
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      mode: min
      patience: 5

    # --------------------------------------------------------------
    # Model Architecture (intended to match the script's model_config)
    # --------------------------------------------------------------
    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      # Only one modality is listed, so this setting matters only once more
      # modalities are added.
      backbone_merge_method: mean

      decoder: UperNetDecoder
      # Presumably rescales the single-resolution ViT features into the
      # multi-scale pyramid UperNet consumes. TODO confirm.
      decoder_scale_modules: true
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.3

      # Neck pipeline for TerraMind:
      # 1. Reshape token sequences into image-like feature maps. No CLS token
      #    is stripped, presumably because TerraMind emits none (confirm).
      # 2. Keep only the encoder outputs at the listed indices (presumably
      #    transformer block indices for the base model).
      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]

# ------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------
data:
  # Resolved by import path, so this module must be importable from the
  # launch environment (working directory / PYTHONPATH).
  class_path: methane_classification_datamodule.MethaneClassificationDataModule
  init_args:
    # NOTE(review): relative paths presumably resolve against the directory
    # training is launched from. Confirm, or switch to absolute paths.
    data_root: ../../MBD_nan_S2_zscore/MBD_nan_S2_zscore
    excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
    batch_size: 8
    val_split: 0.2
    # Same value as seed_everything. Presumably this seeds the
    # train/validation split, so keep the two in sync.
    seed: 42
    # Note: the script's procedural train_test_split logic (fold handling and
    # splitting) should be encapsulated in the DataModule's setup() method
    # for this config to work seamlessly.