DevPanda004 committed
Commit d43e46a · verified · 1 Parent(s): 9aa50f7

Upload 4 files

Files changed (4)
  1. README.md +59 -0
  2. config.yaml +182 -0
  3. model_inference.py +344 -0
  4. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,59 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - Pytorch
+ - segmentation
+ - Flood mapping
+ - Sentinel-2
+ - Geospatial
+ - Foundation model
+ ---
+ ### Model and Inputs
+ The pretrained [Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) model is fine-tuned to segment the extent of floods on Sentinel-2 images from the [Sen1Floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
+
+ The dataset consists of 446 labeled 512x512 chips spanning all 14 biomes, 357 ecoregions, and 6 continents of the world across 11 flood events. The benchmark associated with Sen1Floods11 provides results for fully convolutional neural networks trained in various input/labeled-data setups, considering Sentinel-1 and Sentinel-2 imagery.
+
+ We use the following six bands for flood mapping: Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2.
+
+ Labels represent no water (class 0), water/flood (class 1), and no data/clouds (class -1).
+
+ The Prithvi-EO-2.0-300M-TL model was initially pretrained using a sequence length of 4 timestamps. Based on the characteristics of this benchmark dataset, we focus on single-timestamp segmentation here, demonstrating that the model can be fine-tuned with an arbitrary number of timestamps.
+
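The class convention above can be illustrated with a small, hypothetical NumPy example (the array values are made up; only the class encoding follows the description, with -1 pixels excluded from any statistics):

```python
import numpy as np

# Hypothetical 4x4 label chip using the Sen1Floods11 convention:
# 0 = no water, 1 = water/flood, -1 = no data / clouds.
labels = np.array([
    [0, 0, 1, 1],
    [0, 1, 1, -1],
    [0, 0, 1, -1],
    [0, 0, 0, -1],
])

valid = labels != -1                 # mask out no-data/cloud pixels
flood_fraction = (labels[valid] == 1).mean()
print(flood_fraction)                # fraction of valid pixels labeled water/flood
```

The same ignore-the-`-1`-class rule appears in the training config as `ignore_index: -1` for the cross-entropy loss.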
+ ### Fine-tuning
+
+ The model was fine-tuned using [TerraTorch](https://github.com/IBM/terratorch):
+
+ ```shell
+ terratorch fit -c sen1floods11.yaml
+ ```
+
+ The configuration used for fine-tuning is available through this [config](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/configs/sen1floods11.yaml).
+
+ ### Inference and demo
+
+ A **demo** running this model is available **[here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Sen1Floods11-demo)**.
+
+ This repo includes an inference script that runs the flood model on Sentinel-2 L1C images:
+
+ ```shell
+ python model_inference.py --data_file examples/India_900498_S2Hand.tif
+ ```
+
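Internally, the inference script reflect-pads the input so that it tiles evenly into 512x512 windows before running the model window by window. A minimal NumPy-only sketch of that padding arithmetic (the function name is illustrative, not part of the script):

```python
import numpy as np

def pad_to_multiple(arr, img_size=512):
    # Mirror of the script's logic: reflect-pad H and W up to the next
    # multiple of img_size so the image tiles evenly into img_size windows.
    h, w = arr.shape[-2:]
    pad_h = (img_size - h % img_size) % img_size
    pad_w = (img_size - w % img_size) % img_size
    return np.pad(arr, ((0, 0), (0, pad_h), (0, pad_w)), mode="reflect")

x = np.zeros((6, 700, 1300), dtype=np.float32)  # 6 bands, arbitrary spatial size
print(pad_to_multiple(x).shape)  # -> (6, 1024, 1536)
```

After prediction, the script crops the padded border away so the output matches the original raster extent.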
+ ### Feedback
+
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it by submitting issues on GitHub or starting a discussion on Hugging Face.
+
+ ### Citation
+
+ If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.
+
+ ```
+ @article{Prithvi-EO-V2-preprint,
+   author  = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
+   title   = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
+   journal = {arXiv preprint arXiv:2412.02732},
+   year    = {2024}
+ }
+ ```
config.yaml ADDED
@@ -0,0 +1,182 @@
+ # lightning.pytorch==2.4.0
+ seed_everything: 0
+ trainer:
+   accelerator: auto
+   strategy: auto
+   devices: auto
+   num_nodes: 1
+   precision: 16-mixed
+   logger: true
+   callbacks:
+   - class_path: lightning.pytorch.callbacks.RichProgressBar
+     init_args:
+       refresh_rate: 1
+       leave: false
+       theme:
+         description: white
+         progress_bar: '#6206E0'
+         progress_bar_finished: '#6206E0'
+         progress_bar_pulse: '#6206E0'
+         batch_progress: white
+         time: grey54
+         processing_speed: grey70
+         metrics: white
+         metrics_text_delimiter: ' '
+         metrics_format: .3f
+   - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+     init_args:
+       logging_interval: epoch
+       log_momentum: false
+       log_weight_decay: false
+   - class_path: lightning.pytorch.callbacks.EarlyStopping
+     init_args:
+       monitor: val/loss
+       min_delta: 0.0
+       patience: 20
+       verbose: false
+       mode: min
+       strict: true
+       check_finite: true
+       log_rank_zero_only: false
+   fast_dev_run: false
+   max_epochs: 50
+   max_steps: -1
+   overfit_batches: 0.0
+   check_val_every_n_epoch: 2
+   log_every_n_steps: 10
+   enable_checkpointing: true
+   accumulate_grad_batches: 1
+   inference_mode: true
+   use_distributed_sampler: true
+   detect_anomaly: false
+   barebones: false
+   sync_batchnorm: false
+   reload_dataloaders_every_n_epochs: 0
+   default_root_dir: /dccstor/geofm-finetuning/benchmark-geo-bench-paolo/
+ model:
+   class_path: terratorch.tasks.SemanticSegmentationTask
+   init_args:
+     model_args:
+       backbone_pretrained: true
+       backbone: prithvi_eo_v2_300_tl
+       decoder: UperNetDecoder
+       decoder_channels: 256
+       decoder_scale_modules: true
+       num_classes: 2
+       rescale: true
+       backbone_bands:
+       - BLUE
+       - GREEN
+       - RED
+       - NIR_NARROW
+       - SWIR_1
+       - SWIR_2
+       head_dropout: 0.1
+       necks:
+       - name: SelectIndices
+         indices:
+         - 5
+         - 11
+         - 17
+         - 23
+       - name: ReshapeTokensToImage
+     model_factory: EncoderDecoderFactory
+     loss: ce
+     ignore_index: -1
+     lr: 0.001
+     freeze_backbone: false
+     freeze_decoder: false
+     plot_on_val: 10
+ data:
+   class_path: terratorch.datamodules.Sen1Floods11NonGeoDataModule
+   init_args:
+     data_root: /dccstor/geofm-finetuning/datasets/sen1floods11
+     batch_size: 16
+     num_workers: 8
+     bands:
+     - BLUE
+     - GREEN
+     - RED
+     - NIR_NARROW
+     - SWIR_1
+     - SWIR_2
+     train_transform:
+     - class_path: albumentations.Resize
+       init_args:
+         height: 448
+         width: 448
+         interpolation: 1
+         always_apply: false
+         p: 1
+     - class_path: albumentations.RandomCrop
+       init_args:
+         height: 224
+         width: 224
+         always_apply: false
+         p: 1.0
+     - class_path: albumentations.HorizontalFlip
+       init_args:
+         always_apply: false
+         p: 0.5
+     - class_path: albumentations.VerticalFlip
+       init_args:
+         always_apply: false
+         p: 0.5
+     - class_path: albumentations.pytorch.ToTensorV2
+       init_args:
+         transpose_mask: false
+         always_apply: true
+         p: 1.0
+     val_transform:
+     - class_path: albumentations.Resize
+       init_args:
+         height: 448
+         width: 448
+         interpolation: 1
+         always_apply: false
+         p: 1
+     - class_path: albumentations.pytorch.ToTensorV2
+       init_args:
+         transpose_mask: false
+         always_apply: true
+         p: 1.0
+     test_transform:
+     - class_path: albumentations.Resize
+       init_args:
+         height: 448
+         width: 448
+         interpolation: 1
+         always_apply: false
+         p: 1
+     - class_path: albumentations.pytorch.ToTensorV2
+       init_args:
+         transpose_mask: false
+         always_apply: true
+         p: 1.0
+     drop_last: true
+     constant_scale: 0.0001
+     no_data_replace: 0.0
+     no_label_replace: -1
+     use_metadata: false
+ out_dtype: int16
+ deploy_config_file: true
+ optimizer:
+   class_path: torch.optim.AdamW
+   init_args:
+     lr: 5.0e-05
+     betas:
+     - 0.9
+     - 0.999
+     eps: 1.0e-08
+     weight_decay: 0.05
+     amsgrad: false
+     maximize: false
+     capturable: false
+     differentiable: false
+ lr_scheduler:
+   class_path: torch.optim.lr_scheduler.CosineAnnealingLR
+   init_args:
+     T_max: 50
+     eta_min: 0
+     last_epoch: -1
+     verbose: deprecated
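For reference, the `CosineAnnealingLR` schedule configured above (base lr `5.0e-05`, `T_max: 50`, `eta_min: 0`) follows the standard closed form documented for `torch.optim.lr_scheduler.CosineAnnealingLR`; a pure-Python sketch of the per-epoch value (not using torch):

```python
import math

BASE_LR, ETA_MIN, T_MAX = 5.0e-05, 0.0, 50  # values from the optimizer/lr_scheduler config

def cosine_lr(epoch):
    # Closed-form cosine annealing: decays from BASE_LR at epoch 0
    # down to ETA_MIN at epoch T_MAX.
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1 + math.cos(math.pi * epoch / T_MAX))

print(cosine_lr(0))   # base lr at the start
print(cosine_lr(25))  # half the base lr at mid-schedule
print(cosine_lr(50))  # annealed to eta_min
```

Note that `T_max` matches `max_epochs: 50`, so the schedule completes exactly one half-cosine over the training run.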
model_inference.py ADDED
@@ -0,0 +1,344 @@
+
+ import argparse
+ import os
+ from typing import List, Union
+ import re
+ import datetime
+ import numpy as np
+ import rasterio
+ import torch
+ import yaml
+ from einops import rearrange
+ from terratorch.cli_tools import LightningInferenceModel
+
+ NO_DATA = -9999
+ NO_DATA_FLOAT = 0.0001
+ OFFSET = 0
+ PERCENTILE = 99
+
+
+ def process_channel_group(orig_img, channels):
+     """
+     Args:
+         orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
+         channels: list of indices representing RGB channels.
+
+     Returns:
+         torch.Tensor with shape (num_channels, height, width) for original image
+     """
+
+     orig_img = orig_img[channels, ...]
+     valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
+     valid_mask[orig_img == NO_DATA_FLOAT] = False
+
+     # Rescale (enhancing contrast)
+     max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
+     min_value = OFFSET
+
+     orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
+
+     # No data as zeros
+     orig_img[~valid_mask] = 0
+
+     return orig_img
+
+
+ def read_geotiff(file_path: str):
+     """Read all bands from *file_path* and return image + meta info.
+
+     Args:
+         file_path: path to image file.
+
+     Returns:
+         np.ndarray with shape (bands, height, width)
+         meta info dict
+     """
+
+     with rasterio.open(file_path) as src:
+         img = src.read()
+         meta = src.meta
+         try:
+             coords = src.lnglat()
+         except Exception:
+             # Cannot read coords
+             coords = None
+
+     return img, meta, coords
+
+
+ def save_geotiff(image, output_path: str, meta: dict):
+     """Save multi-band image in Geotiff file.
+
+     Args:
+         image: np.ndarray with shape (bands, height, width)
+         output_path: path where to save the image
+         meta: dict with meta info.
+     """
+
+     with rasterio.open(output_path, "w", **meta) as dest:
+         for i in range(image.shape[0]):
+             dest.write(image[i, :, :], i + 1)
+
+
+ def _convert_np_uint8(float_image: torch.Tensor):
+     image = float_image.numpy() * 255.0
+     image = image.astype(dtype=np.uint8)
+
+     return image
+
+
+ def load_example(
+     file_paths: List[str],
+     mean: List[float] = None,
+     std: List[float] = None,
+     indices: Union[list[int], None] = None,
+ ):
+     """Build an input example by loading images in *file_paths*.
+
+     Args:
+         file_paths: list of file paths.
+         mean: list containing mean values for each band in the images in *file_paths*.
+         std: list containing std values for each band in the images in *file_paths*.
+         indices: optional 0-based indices of the bands to select from each image.
+
+     Returns:
+         np.array containing created example
+         list of meta info for each image in *file_paths*
+     """
+
+     imgs = []
+     metas = []
+     temporal_coords = []
+     location_coords = []
+
+     for file in file_paths:
+         img, meta, coords = read_geotiff(file)
+
+         # Rescaling (don't normalize on nodata)
+         img = np.moveaxis(img, 0, -1)  # channels last for rescaling
+         if indices is not None:
+             img = img[..., indices]
+         if mean is not None and std is not None:
+             img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
+
+         imgs.append(img)
+         metas.append(meta)
+         if coords is not None:
+             location_coords.append(coords)
+
+         try:
+             match = re.search(r'(\d{7,8}T\d{6})', file)
+             if match:
+                 year = int(match.group(1)[:4])
+                 julian_day = match.group(1).split('T')[0][4:]
+                 if len(julian_day) == 3:
+                     julian_day = int(julian_day)
+                 else:
+                     julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
+                 temporal_coords.append([year, julian_day])
+         except Exception as e:
+             print(f'Could not extract timestamp for {file} ({e})')
+
+     imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
+     imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
+     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
+
+     return imgs, temporal_coords, location_coords, metas
+
+
+ def run_model(input_data, temporal_coords, location_coords, model, datamodule, img_size):
+     # Reflect pad if not divisible by img_size
+     original_h, original_w = input_data.shape[-2:]
+     pad_h = (img_size - (original_h % img_size)) % img_size
+     pad_w = (img_size - (original_w % img_size)) % img_size
+     input_data = np.pad(
+         input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
+     )
+
+     # Build sliding window
+     batch_size = 1
+     batch = torch.tensor(input_data, device="cpu")
+     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
+     h1, w1 = windows.shape[3:5]
+     windows = rearrange(
+         windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
+     )
+
+     # Split into batches if number of windows > batch_size
+     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
+     windows = torch.tensor_split(windows, num_batches, dim=0)
+
+     if temporal_coords:
+         temporal_coords = torch.tensor(temporal_coords, dtype=torch.float32, device=model.device).unsqueeze(0)
+     else:
+         temporal_coords = None
+     if location_coords:
+         location_coords = torch.tensor(location_coords[0], dtype=torch.float32, device=model.device).unsqueeze(0)
+     else:
+         location_coords = None
+
+     # Run model
+     pred_imgs = []
+     for x in windows:
+         # Apply standardization
+         x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
+         x = datamodule.aug(x['image'])
+
+         with torch.no_grad():
+             x = x.to(model.device)
+             pred = model(x, temporal_coords=temporal_coords, location_coords=location_coords)
+             pred = pred.output.detach().cpu()
+
+         y_hat = pred.argmax(dim=1)
+         y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), size=img_size, mode="nearest")
+
+         pred_imgs.append(y_hat)
+
+     pred_imgs = torch.concat(pred_imgs, dim=0)
+
+     # Build images from patches
+     pred_imgs = rearrange(
+         pred_imgs,
+         "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
+         h=img_size,
+         w=img_size,
+         b=1,
+         c=1,
+         h1=h1,
+         w1=w1,
+     )
+
+     # Cut padded area back to original size
+     pred_imgs = pred_imgs[..., :original_h, :original_w]
+
+     # Squeeze (batch size 1)
+     pred_imgs = pred_imgs[0]
+
+     return pred_imgs
+
+
+ def main(
+     data_file: str,
+     config: str,
+     checkpoint: str,
+     output_dir: str,
+     rgb_outputs: bool,
+     input_indices: list[int] = None,
+ ):
+     os.makedirs(output_dir, exist_ok=True)
+
+     with open(config, "r") as f:
+         config_dict = yaml.safe_load(f)
+
+     # Load model ---------------------------------------------------------------------------------
+
+     lightning_model = LightningInferenceModel.from_config(config, checkpoint)
+     img_size = 512  # Chip size of Sen1Floods11
+
+     # Load data ----------------------------------------------------------------------------------
+
+     input_data, temporal_coords, location_coords, meta_data = load_example(
+         file_paths=[data_file], indices=input_indices,
+     )
+
+     meta_data = meta_data[0]  # only one image
+
+     if input_data.mean() > 1:
+         input_data = input_data / 10000  # Convert to range 0-1
+
+     # Run model ----------------------------------------------------------------------------------
+
+     lightning_model.model.eval()
+
+     channels = [config_dict['data']['init_args']['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]]  # BGR -> RGB
+
+     pred = run_model(input_data, temporal_coords, location_coords,
+                      lightning_model.model, lightning_model.datamodule, img_size)
+
+     # Save prediction
+     meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
+     pred_file = os.path.join(output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+     save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
+
+     # Save image + prediction overlay
+     meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)
+
+     if input_data.mean() < 1:
+         input_data = input_data * 10000  # Scale to 0-10000
+
+     rgb_orig = process_channel_group(
+         orig_img=torch.Tensor(input_data[0, :, 0, ...]),
+         channels=channels,
+     )
+
+     pred[pred == 0.] = np.nan
+     img_pred = rgb_orig * 0.7 + pred * 0.3
+     img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
+
+     img_pred_file = os.path.join(output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+     save_geotiff(
+         image=_convert_np_uint8(img_pred),
+         output_path=img_pred_file,
+         meta=meta_data,
+     )
+
+     # Save original RGB image
+     if rgb_outputs:
+         rgb_file = os.path.join(output_dir, f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
+         save_geotiff(
+             image=_convert_np_uint8(rgb_orig),
+             output_path=rgb_file,
+             meta=meta_data,
+         )
+
+     print("Done!")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+     parser.add_argument(
+         "--data_file",
+         type=str,
+         default="examples/India_900498_S2Hand.tif",
+         help="Path to the input file.",
+     )
+     parser.add_argument(
+         "--config",
+         "-c",
+         type=str,
+         default="config.yaml",
+         help="Path to yaml file containing model parameters.",
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default="Prithvi-EO-V2-300M-TL-Sen1Floods11.pt",
+         help="Path to a checkpoint file to load from.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="output",
+         help="Path to the directory where to save outputs.",
+     )
+     parser.add_argument(
+         "--input_indices",
+         default=[1, 2, 3, 8, 11, 12],
+         type=int,
+         nargs="+",
+         help="0-based indices of the six Prithvi channels to be selected from the input. "
+              "By default selects [1,2,3,8,11,12] for S2L1C data.",
+     )
+     parser.add_argument(
+         "--rgb_outputs",
+         action="store_true",
+         help="If present, additionally save the original RGB image alongside the predictions.",
+     )
+     args = parser.parse_args()
+
+     main(**vars(args))
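The tile-and-stitch round trip in `run_model` (unfold into fixed-size windows, predict per window, rearrange back, crop away the padding) can be illustrated with a NumPy-only sketch; the tile size is shrunk to 4 for readability, and the function name is illustrative rather than part of the script:

```python
import numpy as np

def tile_and_stitch(img, tile=4):
    # Reflect-pad to a multiple of the tile size (as the script does),
    # cut into tiles, then stitch them back and crop to the original size.
    h, w = img.shape
    pad_h = (tile - h % tile) % tile
    pad_w = (tile - w % tile) % tile
    padded = np.pad(img, ((0, pad_h), (0, pad_w)), mode="reflect")
    h1, w1 = padded.shape[0] // tile, padded.shape[1] // tile
    tiles = padded.reshape(h1, tile, w1, tile).swapaxes(1, 2)  # (h1, w1, tile, tile)
    # ... a per-tile model would run here ...
    stitched = tiles.swapaxes(1, 2).reshape(h1 * tile, w1 * tile)
    return stitched[:h, :w]

img = np.arange(7 * 10, dtype=float).reshape(7, 10)
assert np.array_equal(tile_and_stitch(img), img)  # identity when no model is applied
```

Because the reassembled mosaic is cropped back to the original height and width, predictions align pixel-for-pixel with the input GeoTIFF and can reuse its georeferencing metadata.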
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchvision
+ timm
+ einops
+ rasterio
+ git+https://github.com/IBM/terratorch.git