TerraTorch
blumenstiel committed
Commit 56be098 · 1 Parent(s): 2918b77

Update ReadMe

Files changed (2):
  1. README.md +65 -0
  2. terramind_v1_base_impactmesh_fire.yaml +123 -0
README.md CHANGED
@@ -1,3 +1,68 @@
---
license: apache-2.0
library_name: terratorch
datasets:
- ibm-esa-geospatial/ImpactMesh-Fire
base_model:
- ibm-esa-geospatial/TerraMind-1.0-base
---

[![arXiv](https://img.shields.io/badge/arXiv-coming_soon-b31b1b?logo=arxiv)](https://arxiv.org/abs/todo)
[![Code](https://img.shields.io/badge/GitHub-ImpactMesh-EE4B2B?logo=github)](https://github.com/IBM/ImpactMesh)
[![IBMblog](https://img.shields.io/badge/Blog-IBM-0F62FE)](https://research.ibm.com/blog/todo)

# TerraMind-base-Fire

TerraMind-base-Fire is based on [TerraMind-1.0-base](https://huggingface.co/ibm-esa-geospatial/TerraMind-1.0-base) and was fine-tuned on [ImpactMesh-Fire](https://huggingface.co/datasets/ibm-esa-geospatial/ImpactMesh-Fire) using [TerraTorch](https://terrastackai.github.io/terratorch/stable/).
We use the [Temporal Wrapper](https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/) for a mid-fusion approach: the backbone processes the multimodal input of each timestep, while the decoder fuses the multi-temporal information.
We refer to our technical report (coming soon!) for details.

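The mid-fusion idea can be sketched in a few lines. This is an illustrative toy, not TerraTorch's actual Temporal Wrapper implementation: `backbone` is a random stand-in encoder, and the token/embedding shapes assume a ViT-base-like patch embedding.

```python
import torch

B, C, T, H, W = 2, 12, 4, 64, 64  # batch, channels, timesteps, height, width
embed_dim, patch = 768, 16        # assumed ViT-base-like settings

def backbone(x):
    # Stand-in encoder: [B, C, H, W] -> [B, N_tokens, embed_dim]
    n_tokens = (x.shape[-2] // patch) * (x.shape[-1] // patch)
    return torch.randn(x.shape[0], n_tokens, embed_dim)

x = torch.randn(B, C, T, H, W)
# Encode each timestep independently ("mid fusion": fusion happens after the backbone)
per_step = [backbone(x[:, :, t]) for t in range(T)]  # T tensors of [B, N, D]
# "concat" temporal pooling: stack per-timestep features along the channel dim
fused = torch.cat(per_step, dim=-1)                  # [B, N, T * D]
```

The decoder then operates on `fused`, so the temporal information is combined after the multimodal backbone.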
## Usage

Quickstart: install TerraTorch and the ImpactMesh DataModules, or download them from [GitHub](https://github.com/IBM/ImpactMesh):

```shell
pip install git+https://github.com/terrastackai/terratorch.git@multimodal
pip install impactmesh
```

Load the model via Lightning:

```python
import torch
from terratorch.cli_tools import LightningInferenceModel

# Load the TerraTorch task from the config and checkpoint
task = LightningInferenceModel.from_config(
    "terramind_v1_base_impactmesh_fire.yaml",
    "TerraMind_v1_base_ImpactMesh_fire.pt",
)

model = task.model.model  # Get the model from the Lightning task
model.eval()

# Inputs with shape [B, C, T, H, W]
inputs = {
    "S2L2A": torch.randn([1, 12, 4, 256, 256]),
    "S1RTC": torch.randn([1, 2, 4, 256, 256]),
    "DEM": torch.randn([1, 1, 4, 256, 256]),  # Repeated per timestep
}

# Run inference
with torch.no_grad():
    pred = model(inputs).output

y_hat = pred.argmax(dim=1)  # [B, H, W] segmentation map
```

Run predictions via the TerraTorch CLI:

```shell
terratorch predict -c "terramind_v1_base_impactmesh_fire.yaml" --ckpt "TerraMind_v1_base_ImpactMesh_fire.pt" --predict_output_dir output/impactmesh_fire_predictions --predict_data_root "path/to/data/"
```

For prediction, the ImpactMesh data module expects a format similar to the training data, with one subfolder per modality containing zarr.zip and tif files.
Alternatively, you can adapt this [inference code](https://github.com/IBM/ImpactMesh/blob/main/impactmesh/run_inference.py).

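A hypothetical example of such a layout (folder and file names below are purely illustrative; see the ImpactMesh repository for the exact expected structure):

```
path/to/data/
├── S2L2A/
│   └── <tile_id>.zarr.zip
├── S1RTC/
│   └── <tile_id>.zarr.zip
└── DEM/
    └── <tile_id>.tif
```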
## Citation

Our technical report will be released soon!
terramind_v1_base_impactmesh_fire.yaml ADDED
@@ -0,0 +1,123 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  # logger: true  # Set to logger: true for Tensorboard
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: ImpactMesh-Fire
      name: TerraMind-Base_lr1e-4
      config:
        modalities: S1-S2-DEM
        dataset: IM-Fire

  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 10
    - class_path: ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_weights_only: true
        dirpath: output/terramind_base_impactmesh_fire/1e-4
        filename: best_val_loss
  max_epochs: 50
  log_every_n_steps: 5
  default_root_dir: output/terramind_base_impactmesh_fire/
data:
  class_path: impactmesh.impactmesh_datamodule.ImpactMeshDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    data_root: data/ImpactMesh-Fire/data
    train_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_train.txt
    val_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_val.txt
    test_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_test.txt
    timesteps: [0, 1, 2, 3]
    modalities:
      - S2L2A
      - S1RTC
      - DEM
    no_data_replace: 0
    train_transform:
      - class_path: terratorch.datasets.transforms.FlattenTemporalIntoChannels
      - class_path: albumentations.D4
      - class_path: albumentations.pytorch.ToTensorV2
      - class_path: terratorch.datasets.transforms.UnflattenTemporalFromChannels
        init_args:
          n_timesteps: 4
    # Use pretraining stats with frozen encoder
    # means:
    #   S2L2A: [ 1390.458, 1503.317, 1718.197, 1853.91, 2199.1, 2779.975, 2987.011, 3083.234, 3132.22, 3162.988, 2424.884, 1857.648 ]
    #   S1RTC: [ -10.93, -17.329 ]
    #   DEM: [ 670.665 ]
    # stds:
    #   S2L2A: [ 2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311 ]
    #   S1RTC: [ 4.391, 4.459 ]
    #   DEM: [ 951.272 ]

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
        - S1RTC
        - DEM

      # Apply temporal wrapper (docs: https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/)
      backbone_use_temporal: true
      backbone_temporal_pooling: concat
      backbone_temporal_n_timestamps: 4

      necks:
        - name: SelectIndices
          indices: [2, 5, 8, 11]  # indices for terramind_v1_tiny, small, and base
          # indices: [5, 11, 17, 23]  # large version
        - name: ReshapeTokensToImage
          remove_cls_token: False
        - name: LearnedInterpolateToPyramidal

      decoder: UNetDecoder
      # decoder_channels: [256, 128, 64, 32]  # tiny
      decoder_channels: [512, 256, 128, 64]  # base

      head_dropout: 0.1
      num_classes: 2
    loss: dice
    ignore_index: -1
    freeze_backbone: false
    freeze_decoder: false
    class_weights: [0.342, 1.316]
    # For prediction: overlap of 16 pixels on each side, 8 pixels dropped
    tiled_inference_parameters:
      crop: 256
      stride: 208
      batch_size: 64
      delta: 8

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.e-4
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
    factor: 0.5
    patience: 2
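The tiled-inference comment in the config can be sanity-checked with a little arithmetic. This sketch only mirrors the numbers given above (crop 256, stride 208, delta 8); `tile_starts` is an illustrative helper, not TerraTorch's actual tiling implementation:

```python
crop, stride, delta = 256, 208, 8

overlap = crop - stride    # 48 px shared between adjacent tiles
per_side = overlap // 2    # 24 px of overlap on each side of a tile
blended = per_side - delta # 16 px of overlap kept, 8 px (delta) dropped

def tile_starts(size, crop, stride):
    """Start offsets of crops covering an axis of length `size`;
    the last crop is shifted back so it ends exactly at the border."""
    starts = list(range(0, max(size - crop, 0) + 1, stride))
    if starts[-1] + crop < size:
        starts.append(size - crop)
    return starts
```

For a 512-pixel-wide scene this yields crops starting at 0, 208, and 256, so every pixel is covered and tile borders overlap as described in the config comment.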