Commit ·
56be098
1
Parent(s): 2918b77
Update ReadMe
Browse files- README.md +65 -0
- terramind_v1_base_impactmesh_fire.yaml +123 -0
README.md
CHANGED
|
@@ -1,3 +1,68 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
library_name: terratorch
|
| 4 |
+
datasets:
|
| 5 |
+
- ibm-esa-geospatial/ImpactMesh-Fire
|
| 6 |
+
base_model:
|
| 7 |
+
- ibm-esa-geospatial/TerraMind-1.0-base
|
| 8 |
---
|
| 9 |
+
|
| 10 |
+
[](https://arxiv.org/abs/todo)
|
| 11 |
+
[](https://github.com/IBM/ImpactMesh)
|
| 12 |
+
[](https://research.ibm.com/blog/todo)
|
| 13 |
+
|
| 14 |
+
# TerraMind-base-Fire
|
| 15 |
+
|
| 16 |
+
TerraMind-base-Fire is based on [TerraMind-1.0-base](https://huggingface.co/ibm-esa-geospatial/TerraMind-1.0-base) and was fine-tuned on [ImpactMesh-Fire](https://huggingface.co/datasets/ibm-esa-geospatial/ImpactMesh-FLood) using [TerraTorch](https://terrastackai.github.io/terratorch/stable/).
|
| 17 |
+
We use the [Temporal Wrapper](https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/) for a mid-fusion approach. The backbone processes the multimodal input while the decoder fuses the multi-temporal information.
|
| 18 |
+
We refer to our technical report for details (coming soon!).
|
| 19 |
+
|
| 20 |
+
## Usage
|
| 21 |
+
|
| 22 |
+
Quickstart with installing TerraTorch and the ImpactMesh DataModules or download them from [GitHub](https://github.com/IBM/ImpactMesh):
|
| 23 |
+
```shell
|
| 24 |
+
pip install git+https://github.com/terrastackai/terratorch.git@multimodal
|
| 25 |
+
pip install impactmesh
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Load the model via Lightning:
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
import torch
|
| 32 |
+
from terratorch.cli_tools import LightningInferenceModel
|
| 33 |
+
|
| 34 |
+
# Load TerraTorch task from
|
| 35 |
+
task = LightningInferenceModel.from_config(
|
| 36 |
+
"terramind_v1_base_impactmesh_fire.yaml",
|
| 37 |
+
"TerraMind_v1_base_ImpactMesh_fire.pt",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
model = task.model.model # Get model from Lighting task
|
| 41 |
+
model.eval()
|
| 42 |
+
|
| 43 |
+
# Inputs with shape [B, C, T, H, W]
|
| 44 |
+
input = {
|
| 45 |
+
"S2L2A": torch.randn([1, 12, 4, 256, 256]),
|
| 46 |
+
"S1RTC": torch.randn([1, 2, 4, 256, 256]),
|
| 47 |
+
"DEM": torch.randn([1, 1, 4, 256, 256]), # Repeated per timestep
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
# Run inference
|
| 51 |
+
with torch.no_grad():
|
| 52 |
+
pred = model(input).output
|
| 53 |
+
|
| 54 |
+
y_hat = pred.argmax(dim=1)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Run predictions via the TerraTorch CLI:
|
| 58 |
+
|
| 59 |
+
```shell
|
| 60 |
+
terratorch predict -c "terramind_v1_base_impactmesh_fire.yaml" --ckpt "TerraMind_v1_base_ImpactMesh_fire.pt" --predict_output_dir output/impactmesh_fire_predictions --predict_data_root "path/to/data/"
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
For prediction, the ImpactMesh data module expects a format similar to the training data with subfolders per modality and zarr.zip and tif files.
|
| 64 |
+
Alternatively, you can adapt this [inference code](https://github.com/IBM/ImpactMesh/blob/main/impactmesh/run_inference.py).
|
| 65 |
+
|
| 66 |
+
## Citation
|
| 67 |
+
|
| 68 |
+
Our technical report is released soon!
|
terramind_v1_base_impactmesh_fire.yaml
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# lightning.pytorch==2.1.1
|
| 2 |
+
seed_everything: 42
|
| 3 |
+
trainer:
|
| 4 |
+
accelerator: auto
|
| 5 |
+
strategy: auto
|
| 6 |
+
devices: auto
|
| 7 |
+
num_nodes: 1
|
| 8 |
+
precision: 16-mixed
|
| 9 |
+
# logger: true # Set to logger: true for Tensorboard
|
| 10 |
+
logger:
|
| 11 |
+
class_path: lightning.pytorch.loggers.WandbLogger
|
| 12 |
+
init_args:
|
| 13 |
+
project: ImpactMesh-Fire
|
| 14 |
+
name: TerraMind-Base_lr1e-4
|
| 15 |
+
config:
|
| 16 |
+
modalities: S1-S2-DEM
|
| 17 |
+
dataset: IM-Fire
|
| 18 |
+
|
| 19 |
+
callbacks:
|
| 20 |
+
- class_path: RichProgressBar
|
| 21 |
+
- class_path: LearningRateMonitor
|
| 22 |
+
init_args:
|
| 23 |
+
logging_interval: epoch
|
| 24 |
+
- class_path: EarlyStopping
|
| 25 |
+
init_args:
|
| 26 |
+
monitor: val/loss
|
| 27 |
+
patience: 10
|
| 28 |
+
- class_path: ModelCheckpoint
|
| 29 |
+
init_args:
|
| 30 |
+
monitor: val/loss
|
| 31 |
+
mode: min
|
| 32 |
+
save_weights_only: true
|
| 33 |
+
dirpath: output/terramind_base_impactmesh_fire/1e-4
|
| 34 |
+
filename: best_val_loss
|
| 35 |
+
max_epochs: 50
|
| 36 |
+
log_every_n_steps: 5
|
| 37 |
+
default_root_dir: output/terramind_base_impactmesh_fire/
|
| 38 |
+
data:
|
| 39 |
+
class_path: impactmesh.impactmesh_datamodule.ImpactMeshDataModule
|
| 40 |
+
init_args:
|
| 41 |
+
batch_size: 16
|
| 42 |
+
num_workers: 8
|
| 43 |
+
data_root: data/ImpactMesh-Fire/data
|
| 44 |
+
train_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_train.txt
|
| 45 |
+
val_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_val.txt
|
| 46 |
+
test_split: data/ImpactMesh-Fire/split/impactmesh_wildfire_test.txt
|
| 47 |
+
timesteps: [0, 1, 2, 3]
|
| 48 |
+
modalities:
|
| 49 |
+
- S2L2A
|
| 50 |
+
- S1RTC
|
| 51 |
+
- DEM
|
| 52 |
+
no_data_replace: 0
|
| 53 |
+
train_transform:
|
| 54 |
+
- class_path: terratorch.datasets.transforms.FlattenTemporalIntoChannels
|
| 55 |
+
- class_path: albumentations.D4
|
| 56 |
+
- class_path: albumentations.pytorch.ToTensorV2
|
| 57 |
+
- class_path: terratorch.datasets.transforms.UnflattenTemporalFromChannels
|
| 58 |
+
init_args:
|
| 59 |
+
n_timesteps: 4
|
| 60 |
+
# Use pretraining stats with frozen encoder
|
| 61 |
+
# means:
|
| 62 |
+
# S2L2A: [ 1390.458, 1503.317, 1718.197, 1853.91, 2199.1, 2779.975, 2987.011, 3083.234, 3132.22, 3162.988, 2424.884, 1857.648 ]
|
| 63 |
+
# S1RTC: [ -10.93, -17.329 ]
|
| 64 |
+
# DEM: [ 670.665 ]
|
| 65 |
+
# stds:
|
| 66 |
+
# S2L2A: [ 2106.761, 2141.107, 2038.973, 2134.138, 2085.321, 1889.926, 1820.257, 1871.918, 1753.829, 1797.379, 1434.261, 1334.311 ]
|
| 67 |
+
# S1RTC: [ 4.391, 4.459 ]
|
| 68 |
+
# DEM: [ 951.272 ]
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
class_path: terratorch.tasks.SemanticSegmentationTask
|
| 72 |
+
init_args:
|
| 73 |
+
model_factory: EncoderDecoderFactory
|
| 74 |
+
model_args:
|
| 75 |
+
backbone: terramind_v1_base
|
| 76 |
+
backbone_pretrained: true
|
| 77 |
+
backbone_modalities:
|
| 78 |
+
- S2L2A
|
| 79 |
+
- S1RTC
|
| 80 |
+
- DEM
|
| 81 |
+
|
| 82 |
+
# Apply temporal wrapper (docs: https://terrastackai.github.io/terratorch/stable/guide/temporal_wrapper/)
|
| 83 |
+
backbone_use_temporal: true
|
| 84 |
+
backbone_temporal_pooling: concat
|
| 85 |
+
backbone_temporal_n_timestamps: 4
|
| 86 |
+
|
| 87 |
+
necks:
|
| 88 |
+
- name: SelectIndices
|
| 89 |
+
indices: [2, 5, 8, 11] # indices for terramind_v1_tiny, small, and base
|
| 90 |
+
# indices: [5, 11, 17, 23] # large version
|
| 91 |
+
- name: ReshapeTokensToImage
|
| 92 |
+
remove_cls_token: False
|
| 93 |
+
- name: LearnedInterpolateToPyramidal
|
| 94 |
+
|
| 95 |
+
decoder: UNetDecoder
|
| 96 |
+
# decoder_channels: [256, 128, 64, 32] # tiny
|
| 97 |
+
decoder_channels: [512, 256, 128, 64] # base
|
| 98 |
+
|
| 99 |
+
head_dropout: 0.1
|
| 100 |
+
num_classes: 2
|
| 101 |
+
loss: dice
|
| 102 |
+
ignore_index: -1
|
| 103 |
+
freeze_backbone: false
|
| 104 |
+
freeze_decoder: false
|
| 105 |
+
class_weights: [0.342, 1.316]
|
| 106 |
+
# For prediction: overlap of 16 pixel on each side, 8 pixels dropped
|
| 107 |
+
tiled_inference_parameters:
|
| 108 |
+
crop: 256
|
| 109 |
+
stride: 208
|
| 110 |
+
batch_size: 64
|
| 111 |
+
delta: 8
|
| 112 |
+
|
| 113 |
+
optimizer:
|
| 114 |
+
class_path: torch.optim.AdamW
|
| 115 |
+
init_args:
|
| 116 |
+
lr: 1.e-4
|
| 117 |
+
lr_scheduler:
|
| 118 |
+
class_path: ReduceLROnPlateau
|
| 119 |
+
init_args:
|
| 120 |
+
monitor: val/loss
|
| 121 |
+
factor: 0.5
|
| 122 |
+
patience: 2
|
| 123 |
+
|