KPLabs commited on
Commit
97a17c2
·
verified ·
1 Parent(s): 45b7a8f

Upload folder using huggingface_hub

Browse files
Files changed (30) hide show
  1. .gitattributes +1 -0
  2. Methane_benchmark_patches_summary_v3.xlsx +3 -0
  3. README.md +27 -0
  4. classification/config/methane_classification_datamodule.py +116 -0
  5. classification/config/methane_classification_dataset.py +79 -0
  6. classification/config/train.yaml +85 -0
  7. classification/script/methane_classification_datamodule.py +71 -0
  8. classification/script/methane_classification_dataset.py +82 -0
  9. classification/script/train_classification_fine_tuning.py +329 -0
  10. classification_with_text/calculate_embeddings.py +51 -0
  11. classification_with_text/combined_caption_embeddings.csv +0 -0
  12. classification_with_text/script/methan_text_dataset.py +81 -0
  13. classification_with_text/script/methane_text_datamodule.py +72 -0
  14. classification_with_text/script/train_text.py +448 -0
  15. intuition1_classification_finetuning/config/methane_simulated_datamodule.py +116 -0
  16. intuition1_classification_finetuning/config/methane_simulated_dataset.py +66 -0
  17. intuition1_classification_finetuning/config/train.yaml +66 -0
  18. intuition1_classification_finetuning/script/methane_simulated_datamodule.py +72 -0
  19. intuition1_classification_finetuning/script/methane_simulated_dataset.py +81 -0
  20. intuition1_classification_finetuning/script/train_simulated_I1.py +309 -0
  21. sentinel2_classification_finetuning/config/methane_simulated_datamodule.py +119 -0
  22. sentinel2_classification_finetuning/config/methane_simulated_dataset.py +70 -0
  23. sentinel2_classification_finetuning/config/train.yaml +67 -0
  24. sentinel2_classification_finetuning/script/inference_s2_simulated.py +241 -0
  25. sentinel2_classification_finetuning/script/methane_simulated_datamodule.py +72 -0
  26. sentinel2_classification_finetuning/script/methane_simulated_dataset.py +81 -0
  27. sentinel2_classification_finetuning/script/train_simulated_s2.py +314 -0
  28. urban_inference/methane_urban_datamodule.py +71 -0
  29. urban_inference/methane_urban_dataset.py +82 -0
  30. urban_inference/urban_inference.py +221 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Methane_benchmark_patches_summary_v3.xlsx filter=lfs diff=lfs merge=lfs -text
Methane_benchmark_patches_summary_v3.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:522e4342850c50e22decd218f5bd546493931f57cd10642cf6d42d2224ad4890
3
+ size 349508
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FAST-EO Use Case 2 - Methane Detection
2
+
3
+ This directory contains all data and code necessary to recreate experiments conducted for fine-tuning Terramind-Base to detect methane in satellite images. It includes five distinct experiments along with their corresponding datasets. The attached Methane_benchmark_patches_summary_v3.xlsx file provides descriptions for every patch extracted from the Methane Benchmark Dataset (MBD) and defines the fold splits to ensure non-overlapping data. This Excel file is used by the runner scripts to partition the data, typically reserving one fold for testing.
4
+
5
+ Each script includes usage instructions which can be accessed by applying the --help (or -h) flag.
6
+
7
+ ## Important: Ensure the Terramind package is installed before running any experiments.
8
+
9
+ ## Experiment 1: Fine tuning on Methane Benchmark Dataset
10
+
11
+ The first experiment is fine tuning the model on the Methane Benchmark Dataset. The dataset has been attached in the directory `MBD_nan_S2_zscore`, and has already been normalized. The code for running the training is located in the `classification` directory, along with the necessary `dataset` and `dataloader` classes.
12
+
13
+ ## Experiment 2: Fine tuning on MBD with text captions
14
+
15
+ This experiment contains a modified version of the Terramind-Base model, which concatenates the textual embeddings of the text captions for every image with the visual embeddings of the base model. The text embeddings are calculated using the `all-MiniLM-L6-v2` model. All the code, along with embeddings calculation, and data, is available in the `classification_with_text` directory. The original captions are located in `classification_with_text/MBD_text`, and the embeddings are located inside the `combined_caption_embeddings.csv` file.
16
+
17
+ ## Experiment 3: Fine tuning and inference on Sentinel 2 with simulated atmospheric conditions
18
+
19
+ This experiment checks how the Terramind-Base behaves on the Sentinel-2 data with simulated atmospheric conditions. The simulated data comes in both the Top-of-Atmosphere and Bottom-of-Atmosphere variants. The model can either be trained on this data or only run on it, to test how well it generalizes when trained on different data.
20
+
21
+ ## Experiment 4: Fine tuning and inference on Intuition 1 with simulated atmospheric conditions
22
+
23
+ This experiment checks how the Terramind-Base behaves on the Intuition-1 data with simulated atmospheric conditions. The simulated data comes in both the Top-of-Atmosphere and Bottom-of-Atmosphere variants. The model can either be trained on this data or only run on it, to test how well it generalizes when trained on different data.
24
+
25
+ ## Experiment 5: Testing the detector on urban dataset without methane
26
+
27
+ The urban dataset has been prepared to check whether the models really learned to detect methane from multispectral data or just look for urban signatures in the images. None of the images in this dataset contain methane; the goal is to run the models and see how many false positives are returned. Python scripts for loading and running the models are attached.
classification/config/methane_classification_datamodule.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import albumentations as A
3
+ from typing import Optional, List
4
+ from sklearn.model_selection import train_test_split
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+ from methane_classification_dataset import MethaneClassificationDataset
8
+
9
class MethaneClassificationDataModule(NonGeoDataModule):
    """Self-contained data module for Methane Benchmark classification.

    Reads the patch-summary spreadsheet (CSV or Excel), performs the
    train/validation split in ``setup()``, and builds one
    ``MethaneClassificationDataset`` per stage.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs
    ):
        # We pass the dataset class just to satisfy the parent class,
        # but we instantiate the specific datasets in setup()
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root    # root directory with one sub-folder per patch
        self.excel_file = excel_file  # summary spreadsheet with a 'Filename' column
        self.val_split = val_split    # fraction of samples held out for validation
        self.seed = seed              # random_state for the train/val split
        self.batch_size = batch_size
        self.num_workers = num_workers

        # State variables for paths; populated in setup()
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Internal definition of training transforms (augmentation only)."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def setup(self, stage: str = None):
        """Read the summary file, split the paths, and build per-stage datasets."""
        # 1. Read the Excel File (CSV is also accepted, keyed on extension)
        try:
            df = pd.read_csv(self.excel_file) if self.excel_file.endswith('.csv') else pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}")

        # 2. Filter valid paths (checking if Fold column exists or just using all data)
        # Assuming we just use all data in the file and split it 80/20 here.
        # If you need specific Fold filtering, add that logic here.
        all_paths = df['Filename'].tolist()

        # 3. Perform the Split (deterministic via self.seed)
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 4. Instantiate Datasets
        if stage in ("fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,  # No transforms for validation
            )

        if stage in ("test", "predict"):
            # For testing, you might want to use a specific hold-out set.
            # For now, reusing val_paths; add logic here to load a test fold.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )


    def train_dataloader(self):
        # drop_last=True keeps every training batch at full batch_size.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True
        )

    def val_dataloader(self):
        # NOTE(review): drop_last=True on validation discards the final
        # partial batch, so metrics skip up to batch_size-1 samples —
        # confirm this is intended.
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True
        )

    def test_dataloader(self):
        # Same drop_last caveat as val_dataloader applies here.
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True
        )
classification/config/methane_classification_dataset.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
def min_max_normalize(data, new_min=0, new_max=1):
    """Linearly rescale *data* into [new_min, new_max].

    Non-finite entries (NaN, +/-inf) are replaced with finite values drawn
    from the finite portion of the input before scaling, so the output is
    always finite.

    Args:
        data: Array-like numeric input of any shape.
        new_min: Lower bound of the output range.
        new_max: Upper bound of the output range.

    Returns:
        float32 ndarray with the same shape as *data*. A constant (or fully
        non-finite) input yields an array filled with *new_min*.
    """
    data = np.asarray(data, dtype=np.float32)

    # Compute replacement values from the *finite* entries only. The previous
    # code used np.max/np.min directly, which return +/-inf when the input
    # contains infinities, so the infinities were never actually replaced
    # (and NaNs poisoned np.max).
    finite = data[np.isfinite(data)]
    if finite.size == 0:  # nothing usable to normalize against
        return np.full_like(data, new_min, dtype=np.float32)
    data = np.nan_to_num(data, nan=np.min(finite), posinf=np.max(finite), neginf=np.min(finite))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero on constant input
        return np.full_like(data, new_min, dtype=np.float32)

    # The 1e-10 epsilon guards against float round-off in the denominator.
    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
22
+
23
+ class MethaneClassificationDataset(NonGeoDataset):
24
+ def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
25
+ super().__init__()
26
+ self.root_dir = root_dir
27
+ self.transform = transform
28
+ self.data_paths = []
29
+ self.mean = mean if mean else [0.485] * 12 # Default mean if not provided
30
+ self.std = std if std else [0.229] * 12 # Default std if not provided
31
+
32
+ # Collect paths for labelbinary.tif and sCube.tif in selected folders
33
+ for folder_name in paths:
34
+ subdir_path = os.path.join(root_dir, folder_name)
35
+ if os.path.isdir(subdir_path):
36
+ label_path = os.path.join(subdir_path, 'labelbinary.tif')
37
+ scube_path = os.path.join(subdir_path, 'sCube.tif')
38
+
39
+ if os.path.exists(label_path) and os.path.exists(scube_path):
40
+ self.data_paths.append((label_path, scube_path))
41
+
42
+ def __len__(self):
43
+ return len(self.data_paths)
44
+
45
+ def __getitem__(self, idx):
46
+ label_path, scube_path = self.data_paths[idx]
47
+
48
+ # Load the label image (single band)
49
+ with rasterio.open(label_path) as label_src:
50
+ label_image = label_src.read(1) # Shape: [512, 512]
51
+
52
+ # Load the sCube image (multi-band), drop the first band
53
+ with rasterio.open(scube_path) as scube_src:
54
+ scube_image = scube_src.read() # Shape: [13, 512, 512]
55
+
56
+ scube_image = scube_image[[0,1,2,3,4,5,6,7,8,9,11,12], :, :] # Drop first band → Shape: [12, 512, 512]
57
+
58
+ # Convert to PyTorch tensors
59
+ scube_tensor = torch.from_numpy(scube_image).float() # Shape: [12, 512, 512]
60
+ label_tensor = torch.from_numpy(label_image).float() # Shape: [512, 512]
61
+
62
+ # Resize to [12, 224, 224] and [224, 224] respectively
63
+ scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
64
+ label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)
65
+
66
+ label_tensor = label_tensor.clip(0, 1) # Clip values to [0, 1]
67
+ scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0) # Replace NaNs with 0
68
+
69
+ # Convert labels to binary
70
+ contains_methane = (label_tensor > 0).any().long()
71
+
72
+ # Apply transformations (if any)
73
+ if self.transform:
74
+ transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
75
+ scube_tensor = transformed['image'].transpose(2, 0, 1) # Convert back to [C, H, W]
76
+
77
+
78
+ return {'image': scube_tensor, 'label': contains_methane, 'gt': label_image, 'sample': scube_path.split('/')[3]}
79
+
classification/config/train.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Global Seed for reproducibility (matches script's set_seed)
2
+ seed_everything: 42
3
+
4
+ # ------------------------------------------------------------------
5
+ # Trainer Configuration
6
+ # ------------------------------------------------------------------
7
+ trainer:
8
+ accelerator: auto # Handles "cuda" if available, else "cpu"
9
+ strategy: auto
10
+ devices: 1
11
+ max_epochs: 100 # Matches args.epochs
12
+ default_root_dir: ./checkpoints
13
+
14
+ # Callbacks to replicate the script's checkpointing and logging
15
+ callbacks:
16
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
17
+ init_args:
18
+ monitor: val/loss
19
+ mode: min
20
+ save_top_k: 1
21
+ filename: "best_model"
22
+ save_last: true # Saves 'last.ckpt' (similar to 'final_model.pth')
23
+
24
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
25
+ init_args:
26
+ logging_interval: epoch
27
+
28
+ # ------------------------------------------------------------------
29
+ # Model Configuration (TerraMind + UperNet)
30
+ # ------------------------------------------------------------------
31
+ model:
32
+ class_path: terratorch.tasks.ClassificationTask
33
+ init_args:
34
+ model_factory: EncoderDecoderFactory
35
+ loss: ce
36
+ ignore_index: -1
37
+ lr: 1.0e-5
38
+ # Optimizer settings matching _init_optimizer
39
+ optimizer: AdamW
40
+ optimizer_hparams:
41
+ weight_decay: 0.05
42
+
43
+ # Scheduler settings matching ReduceLROnPlateau
44
+ scheduler: ReduceLROnPlateau
45
+ scheduler_hparams:
46
+ mode: min
47
+ patience: 5
48
+
49
+ # --------------------------------------------------------------
50
+ # Model Architecture (Exact match to script's model_config)
51
+ # --------------------------------------------------------------
52
+ model_args:
53
+ backbone: terramind_v1_base
54
+ backbone_pretrained: true
55
+ backbone_modalities:
56
+ - S2L2A
57
+ backbone_merge_method: mean
58
+
59
+ decoder: UperNetDecoder
60
+ decoder_scale_modules: true
61
+ decoder_channels: 256
62
+ num_classes: 2
63
+ head_dropout: 0.3
64
+
65
+ # Specific neck configuration for TerraMind
66
+ necks:
67
+ - name: ReshapeTokensToImage
68
+ remove_cls_token: false
69
+ - name: SelectIndices
70
+ indices: [2, 5, 8, 11]
71
+
72
+ # ------------------------------------------------------------------
73
+ # Data Configuration
74
+ # ------------------------------------------------------------------
75
+ data:
76
+ class_path: methane_classification_datamodule.MethaneClassificationDataModule
77
+ init_args:
78
+ data_root: ../../MBD_nan_S2_zscore/MBD_nan_S2_zscore
79
+ excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
80
+ batch_size: 8
81
+ val_split: 0.2
82
+ seed: 42
83
+ # Note: The procedural train_test_split logic from the script
84
+ # (handling folds/splitting) should be encapsulated inside the
85
+ # DataModule's setup() method for this config to work seamlessly.
classification/script/methane_classification_datamodule.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+ from methane_classification_dataset import MethaneClassificationDataset
8
+
9
class MethaneClassificationDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneClassificationDataset.

    One dataset instance is built per Lightning stage from a shared list of
    patch folders; each stage may carry its own transform pipeline.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _make_dataset(self, transform):
        # All stages share the same constructor arguments except the transform.
        return MethaneClassificationDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=self.paths,
            transform=transform,
        )

    def setup(self, stage: str = None):
        """Instantiate the dataset(s) needed for *stage*."""
        if stage in ("fit", "train"):
            self.train_dataset = self._make_dataset(self.train_transform)
        if stage in ("fit", "validate", "val"):
            self.val_dataset = self._make_dataset(self.val_transform)
        if stage in ("test", "predict"):
            self.test_dataset = self._make_dataset(self.test_transform)

    def _make_loader(self, dataset, shuffle):
        # Shared loader settings; drop_last keeps batches at full size.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)
classification/script/methane_classification_dataset.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
def min_max_normalize(data, new_min=0, new_max=1):
    """Linearly rescale *data* into [new_min, new_max].

    Non-finite entries (NaN, +/-inf) are replaced with finite values drawn
    from the finite portion of the input before scaling, so the output is
    always finite.

    Args:
        data: Array-like numeric input of any shape.
        new_min: Lower bound of the output range.
        new_max: Upper bound of the output range.

    Returns:
        float32 ndarray with the same shape as *data*. A constant (or fully
        non-finite) input yields an array filled with *new_min*.
    """
    data = np.asarray(data, dtype=np.float32)

    # Compute replacement values from the *finite* entries only. The previous
    # code used np.max/np.min directly, which return +/-inf when the input
    # contains infinities, so the infinities were never actually replaced
    # (and NaNs poisoned np.max).
    finite = data[np.isfinite(data)]
    if finite.size == 0:  # nothing usable to normalize against
        return np.full_like(data, new_min, dtype=np.float32)
    data = np.nan_to_num(data, nan=np.min(finite), posinf=np.max(finite), neginf=np.min(finite))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero on constant input
        return np.full_like(data, new_min, dtype=np.float32)

    # The 1e-10 epsilon guards against float round-off in the denominator.
    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
22
+
23
class MethaneClassificationDataset(NonGeoDataset):
    """Methane Benchmark patches for binary classification (one-hot labels).

    Each selected folder must contain ``labelbinary.tif`` (segmentation
    mask) and ``sCube.tif`` (13-band data cube). ``__getitem__`` returns a
    12-band 224x224 'S2L2A' tensor, a 2-class one-hot label, the raw mask,
    and the sample name.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
        # Note: excel_file is accepted for API symmetry but is not used here.
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.mean = mean if mean else [0.485] * 12  # Default mean if not provided (stored, not applied below)
        self.std = std if std else [0.229] * 12  # Default std if not provided (stored, not applied below)

        # Collect paths for labelbinary.tif and sCube.tif in selected folders;
        # folders missing either file are silently skipped.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'labelbinary.tif')
                scube_path = os.path.join(subdir_path, 'sCube.tif')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        # Number of (label, cube) pairs discovered at construction time.
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)  # Shape: [512, 512]

        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]
        # (Translated from Polish: "make it so band 10 is not there" — handled
        # by the index list below.)

        # Keep 12 of the 13 bands: band *index 10* is dropped (not the first
        # band, despite the old comment). Shape: [12, 512, 512]
        scube_image = scube_image[[0,1,2,3,4,5,6,7,8,9,11,12], :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [512, 512]

        # Resize to [12, 224, 224] and [224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Convert labels to binary: 1 if any mask pixel is positive
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding (2 classes)
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            # NOTE(review): albumentations returns a numpy HWC array here and
            # .transpose(2, 0, 1) keeps it numpy, not torch — confirm the
            # downstream collate/consumer handles both types.
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)  # Convert back to [C, H, W]

        # NOTE(review): 'sample' assumes a fixed path depth (4th path
        # component) and '/' separators — fragile; confirm against the
        # expected directory layout.
        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'gt': label_image, 'sample': scube_path.split('/')[3]}
classification/script/train_classification_fine_tuning.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Dict, List, Tuple, Any, Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import albumentations as A
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.metrics import (
20
+ accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
21
+ )
22
+ from rasterio.errors import NotGeoreferencedWarning
23
+ import terramind
24
+
25
+ # Local Imports
26
+ from methane_classification_datamodule import MethaneClassificationDataModule
27
+
28
+ # TerraTorch Imports
29
+ from terratorch.tasks import ClassificationTask
30
+
31
+
32
+ # --- Configuration & Setup ---
33
+
34
+ # Configure Logging
35
+ logging.basicConfig(
36
+ level=logging.INFO,
37
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
38
+ datefmt='%Y-%m-%d %H:%M:%S'
39
+ )
40
+ logger = logging.getLogger(__name__)
41
+
42
+ # Suppress Warnings
43
+ logging.getLogger("rasterio._env").setLevel(logging.ERROR)
44
+ warnings.simplefilter("ignore", NotGeoreferencedWarning)
45
+ warnings.filterwarnings("ignore", category=FutureWarning)
46
+
47
def set_seed(seed: int = 42):
    """Sets the seed for reproducibility across random, numpy, and torch."""
    # Feed the same seed to every RNG the training script touches.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Also seed every visible GPU when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
54
+
55
def get_training_transforms() -> A.Compose:
    """Returns the albumentations training pipeline."""
    # Geometric augmentations only; each is applied independently with its
    # own probability.
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
    ]
    return A.Compose(augmentations)
63
+
64
+ # --- Helper Classes ---
65
+
66
class MetricTracker:
    """Accumulates targets and predictions to calculate epoch-level metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state for a fresh epoch."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch: scalar loss plus argmax'ed target/prediction ids."""
        self.total_loss += loss
        self.steps += 1
        # Move class indices to CPU numpy immediately so no GPU tensors
        # are retained across the epoch.
        target_ids = torch.argmax(targets, dim=1).detach().cpu().numpy()
        predicted_ids = torch.argmax(probabilities, dim=1).detach().cpu().numpy()
        self.all_targets.extend(target_ids)
        self.all_predictions.extend(predicted_ids)

    def compute(self) -> Dict[str, float]:
        """Calculates aggregate metrics for everything recorded since reset()."""
        if not self.all_targets:
            return {}

        y_true = self.all_targets
        y_pred = self.all_predictions

        # Confusion-matrix cells for the binary (0/1) problem.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) != 0 else 0.0

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(y_true, y_pred),
            "Specificity": specificity,
            "Sensitivity": recall_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(y_true, y_pred),
        }
100
+
101
class MethaneTrainer:
    """
    Handles the training lifecycle: Model setup, Training loop, Validation, and Checkpointing.
    """
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # One sub-directory per held-out fold keeps the runs separated.
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion  # Retrieved from the TerraTorch task

        # Best validation loss seen so far; drives best-model checkpointing.
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Initializes the TerraTorch Classification Task and Model."""
        # Architecture mirrors classification/config/train.yaml.
        model_config = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_config,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Builds AdamW plus a ReduceLROnPlateau scheduler keyed on val loss."""
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=5, verbose=True
        )
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Runs a single epoch for either training or validation."""
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()

        tracker = MetricTracker()

        # Context manager: enable grad only if training
        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)

            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)
                # NOTE(review): the 'ce' criterion receives softmax
                # *probabilities* here rather than raw logits — cross-entropy
                # losses usually expect logits; confirm this is intentional.
                loss = self.criterion(probabilities, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Update metrics
                tracker.update(loss.item(), targets, probabilities)

                # Update progress bar description with live loss
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def save_checkpoint(self, filename: str):
        """Persist only the model weights (state_dict) to save_dir/filename."""
        path = self.save_dir / filename
        torch.save(self.model.state_dict(), path)
        logger.info(f"Saved model to {path}")

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Appends metrics to the CSV log file."""
        csv_path = self.save_dir / 'train_val_metrics.csv'
        file_exists = csv_path.exists()

        # Define headers based on metric keys (written once, on file creation)
        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(headers)

            row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())
            writer.writerow(row)

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Main training entry point."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            # Run Training & Validation
            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # Scheduler Step (plateau scheduler keyed on validation loss)
            self.scheduler.step(val_metrics['Loss'])

            # Logging
            self.log_to_csv(epoch, train_metrics, val_metrics)
            logger.info(
                f"Train Loss: {train_metrics['Loss']:.4f} | "
                f"Val Loss: {val_metrics['Loss']:.4f} | "
                f"Val F1: {val_metrics['F1']:.4f}"
            )

            # Save Best Model (lowest validation loss so far)
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                self.save_checkpoint("best_model.pth")
                logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")

        # End of training: always keep the final weights too
        self.save_checkpoint("final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
247
+
248
+
249
+ # --- Data Utilities ---
250
+
251
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Prepares DataModule and returns Train/Val loaders.

    Reads the fold summary spreadsheet, keeps every fold except
    ``args.test_fold``, splits the remaining filenames 80/20, and builds
    the two loaders from a single DataModule instance.
    """

    # Read Excel (or CSV, keyed on extension) and Filter Folds
    try:
        df = pd.read_csv(args.excel_file) if args.excel_file.endswith('.csv') else pd.read_excel(args.excel_file)
    except Exception as e:
        logger.error(f"Failed to load summary file: {e}")
        raise

    # Determine training pool (all folds except test_fold)
    all_folds = range(1, args.num_folds + 1)
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # Filter filenames
    df_filtered = df[df['Fold'].isin(train_pool_folds)]
    if df_filtered.empty:
        raise ValueError(f"No data found for folds {train_pool_folds}. Check 'Fold' column in Excel.")

    paths = df_filtered['Filename'].tolist()

    # 80/20 Split, deterministic via args.seed
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # Initialize DataModule
    datamodule = MethaneClassificationDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=train_paths,
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # Create Loaders
    # NOTE(review): the module's `paths` attribute is reassigned between the
    # two setup() calls so the same instance builds the train and val datasets
    # from different folder lists — works, but is stateful and easy to break.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
297
+
298
+
299
+ # --- Main Execution ---
300
+
301
def parse_args():
    """Parse the command-line interface for Methane classification training."""
    parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")
    _add_path_args(parser)
    _add_training_args(parser)
    return parser.parse_args()


def _add_path_args(parser):
    """Input/output location options."""
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for satellite images')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel/CSV file')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory to save outputs')


def _add_training_args(parser):
    """Hyper-parameter and fold-selection options."""
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=2, help='Fold ID to hold out for testing')
    parser.add_argument('--seed', type=int, default=42)
319
+
320
if __name__ == "__main__":
    # Entry point: parse CLI args, seed all RNGs, build loaders, then train.
    args = parse_args()
    set_seed(args.seed)

    # Prepare Data
    train_loader, val_loader = get_data_loaders(args)

    # Initialize Trainer and Start
    trainer = MethaneTrainer(args)
    trainer.fit(train_loader, val_loader)
classification_with_text/calculate_embeddings.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pandas as pd
3
+ from sentence_transformers import SentenceTransformer
4
+ from pathlib import Path
5
+ from tqdm import tqdm
6
+
7
def extract_caption(text_block):
    """Return the text after the first 'CAPTION:' marker in *text_block*.

    The marker is matched case-insensitively; an empty string is returned
    when no line contains the marker.
    """
    marker = "CAPTION:"
    for line in text_block.splitlines():
        # Search case-insensitively, then slice from the position found in the
        # original line. (Previously the check was case-insensitive but the
        # split was case-sensitive, so a lowercase 'caption:' line returned
        # the whole line instead of the caption text.)
        pos = line.upper().find(marker)
        if pos != -1:
            return line[pos + len(marker):].strip()
    return ""
12
+
13
def load_captions_from_files(json_files):
    """Collect (image path, caption) pairs from a list of caption JSON files.

    Each JSON maps an image path to a nested list whose [0][0] element holds
    the caption text block. Entries without a usable caption are skipped.
    Returns two parallel lists: image paths and captions.
    """
    paths, captions = [], []

    for json_path in tqdm(json_files, desc="Reading files"):
        with open(json_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)

        for img_path, entry in payload.items():
            # Guard against empty outer/inner lists before indexing.
            if not entry or not entry[0]:
                continue
            caption = extract_caption(entry[0][0])
            if caption:
                paths.append(img_path)
                captions.append(caption)

    return paths, captions
31
+
32
def compute_and_save_embeddings(json_files, output_csv):
    """Embed every caption found in *json_files* and write them to CSV.

    The output CSV has an 'image_path' column followed by one column per
    embedding dimension. Nothing is written when no captions are found.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    image_paths, captions = load_captions_from_files(json_files)

    if not captions:
        print("No valid captions found across input files.")
        return

    vectors = encoder.encode(captions, show_progress_bar=True)
    table = pd.DataFrame(vectors)
    table.insert(0, "image_path", image_paths)
    table.to_csv(output_csv, index=False)
    print(f"Saved {len(table)} embeddings from {len(json_files)} files to {output_csv}")
45
+
46
# Example usage
if __name__ == "__main__":
    import glob

    # Collect every caption JSON under MBD_text (paths may also be listed manually).
    json_inputs = glob.glob("./MBD_text/*.json")
    compute_and_save_embeddings(json_inputs, "combined_caption_embeddings.csv")
classification_with_text/combined_caption_embeddings.csv ADDED
The diff for this file is too large to render. See raw diff
 
classification_with_text/script/methan_text_dataset.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+ import json
10
+
11
class MethaneTextDataset(NonGeoDataset):
    """Dataset pairing Sentinel-2-style image cubes with a scene-level binary
    methane label and a free-text caption per scene.

    Each sample folder must contain ``labelbinary.tif`` (single-band mask)
    and ``sCube.tif`` (13-band cube; the first band is dropped).
    """

    def __init__(self, root_dir, paths, captions, transform=None):
        """
        Args:
            root_dir: Directory containing one sub-folder per sample.
            paths: Folder names (relative to root_dir) to include.
            captions: Mapping of folder name -> caption string.
            transform: Optional albumentations transform applied to the cube.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.captions_dict = captions

        # Collect (label, cube) file pairs for every requested folder,
        # warning about (and skipping) folders with missing files.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'labelbinary.tif')
                scube_path = os.path.join(subdir_path, 'sCube.tif')
                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))
                else:
                    print(f"Warning: Missing files in {subdir_path}. Expected labelbinary.tif and sCube.tif.")

    def __len__(self):
        """Number of valid (label, cube) pairs discovered at init time."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Return {'S2L2A': [12, 224, 224] cube, 'label': one-hot [2], 'caption': str}."""
        label_path, scube_path = self.data_paths[idx]
        # The caption lookup is keyed by the sample's folder name
        # (second-to-last path component).
        folder_name = label_path.split("/")[-2]

        # Load the single-band label mask.
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)

        # Load the multi-band cube and drop the first band (13 -> 12 bands).
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()[1:, :, :]

        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize: bilinear for imagery, nearest for the mask (keeps it binary).
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Scene-level binary target: any positive mask pixel -> class 1.
        contains_methane = (label_tensor > 0).any().long()
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # NOTE(review): albumentations returns a numpy array here, so for
            # transformed samples 'S2L2A' is an ndarray rather than a
            # torch.Tensor — confirm downstream collation handles this.
            scube_tensor = transformed['image'].transpose(2, 0, 1)

        # Fall back to a placeholder caption for folders without one.
        caption = self.captions_dict.get(folder_name, 'No caption')

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'caption': caption}
classification_with_text/script/methane_text_datamodule.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+
8
+ from methan_text_dataset import MethaneTextDataset
9
+
10
class MethaneTextDataModule(NonGeoDataModule):
    """DataModule wiring MethaneTextDataset into train/val/test dataloaders."""

    def __init__(
        self,
        data_root: str,
        paths: list,
        captions: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneTextDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.paths = paths
        self.captions = captions
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _make_dataset(self, transform):
        # All splits share root/paths/captions and differ only by transform.
        return MethaneTextDataset(
            root_dir=self.data_root,
            paths=self.paths,
            captions=self.captions,
            transform=transform,
        )

    def setup(self, stage: str = None):
        """Instantiate the dataset(s) required for *stage*."""
        if stage in ("fit", "train"):
            self.train_dataset = self._make_dataset(self.train_transform)
        if stage in ("fit", "validate", "val"):
            self.val_dataset = self._make_dataset(self.val_transform)
        if stage in ("test", "predict"):
            self.test_dataset = self._make_dataset(self.test_transform)

    def _loader(self, dataset, shuffle):
        # Shared loader settings: partial final batches are always dropped.
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, drop_last=True
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, False)
classification_with_text/script/train_text.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ import json
8
+ from pathlib import Path
9
+ from functools import partial
10
+ from typing import Dict, List, Tuple, Any, Optional
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ import albumentations as A
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+ from sklearn.model_selection import train_test_split
21
+ from sklearn.metrics import (
22
+ accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
23
+ )
24
+ from rasterio.errors import NotGeoreferencedWarning
25
+ from sentence_transformers import SentenceTransformer
26
+
27
+ # --- CRITICAL IMPORTS ---
28
+ import terramind
29
+ from terratorch.tasks import ClassificationTask
30
+ from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY
31
+ from terramind.models.terramind_register import build_terrammind_vit
32
+
33
+ # Local Imports
34
+ from methane_text_datamodule import MethaneTextDataModule
35
+
36
+ # --- Configuration & Setup ---
37
+
38
# Module-wide logging: timestamped INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Silence noisy rasterio environment logs and known-benign warnings.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- Global Constants ---
# Band names for the untokenized Sentinel-2 L2A modality fed to TerraMind.
PRETRAINED_BANDS = {
    'untok_sen2l2a@224': [
        "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
        "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "SWIR_1", "SWIR_2",
    ]
}
56
+
57
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are seeded separately (no-op on CPU-only machines).
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
63
+
64
def get_training_transforms() -> A.Compose:
    """Augmentation pipeline applied to training samples only."""
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5),
    ]
    return A.Compose(augmentations)
71
+
72
+ # --- Custom Model Components (From Notebook) ---
73
+
74
+ # Initialize Sentence Transformer globally to avoid reloading
75
try:
    # 384-dim MiniLM text encoder used to embed captions during training.
    EMBB_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    # Move to GPU if available for faster encoding during training if needed,
    # though usage in forward() implies dynamic encoding.
    if torch.cuda.is_available():
        EMBB_MODEL = EMBB_MODEL.to("cuda")
except Exception as e:
    # Keep the module importable without the model; any forward() that
    # encodes captions will fail until EMBB_MODEL is available.
    logger.warning(f"Could not load SentenceTransformer: {e}")
    EMBB_MODEL = None
84
+
85
class TerraMindWithText(nn.Module):
    """Wraps the TerraMind ViT backbone and appends a sentence embedding of
    each sample's caption to the returned feature list."""

    def __init__(self, terramind_kwargs: dict):
        """
        Args:
            terramind_kwargs: Extra keyword args forwarded to
                build_terrammind_vit (supplied by the TerraTorch registry call).
        """
        super().__init__()
        # Fixed terramind_v1_base configuration.
        self.terramind = build_terrammind_vit(
            variant='terramind_v1_base',
            encoder_depth=12,
            dim=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=False,
            proj_bias=False,
            mlp_bias=False,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.SiLU,
            gated_mlp=True,
            pretrained_bands=PRETRAINED_BANDS,
            **terramind_kwargs
        )
        # One 768-dim feature per encoder layer, consumed by the decoder.
        self.out_channels = [768] * 12
        # self.project = nn.Linear(768 + 512, 768*192) # Referenced in notebook but seemingly unused in forward

    def forward(self, x, captions):
        # NOTE(review): the list concatenation below implies self.terramind(x)
        # returns a list of per-layer features, not a single (batch, 768)
        # tensor — confirm against the terramind API.
        vision_features = self.terramind(x)

        # Encode captions with the global sentence-transformer; gradients do
        # not flow into the text encoder.
        with torch.no_grad():
            captions_embed = EMBB_MODEL.encode(captions, convert_to_tensor=True, show_progress_bar=False)

        # NOTE(review): .squeeze() with no dim argument also removes a batch
        # dimension of size 1 — verify behaviour for batch_size == 1.
        if len(captions_embed.shape) == 3:
            captions_embed = captions_embed.squeeze()

        # Caption embedding travels as the last element of the feature list.
        return vision_features + [captions_embed]
119
+
120
@TERRATORCH_BACKBONE_REGISTRY.register
def terramind_v1_base_with_text(**kwargs):
    """Registry factory exposing the text-augmented TerraMind backbone to TerraTorch."""
    return TerraMindWithText(terramind_kwargs=kwargs)
123
+
124
@TERRATORCH_DECODER_REGISTRY.register
class SimpleDecoder(nn.Module):
    """Fuses averaged ViT image tokens with a caption embedding via CNNs and
    produces a single-channel 14x14 map.

    NOTE(review): the output is [B, 1, 14, 14] although num_classes=2 is
    accepted — confirm how ClassificationTask consumes this output.
    """

    # Signals to the model factory that this decoder provides its own head.
    includes_head = True

    def __init__(self, input_dim=768, num_classes=2, caption_dim=384):
        """
        Args:
            input_dim: Backbone channel width (list or int; lists use [0]).
            num_classes: Accepted but not used by the current conv head.
            caption_dim: Dimensionality of the caption sentence embedding.
        """
        super().__init__()
        # Handle input_dim if passed as list (common in TerraTorch)
        dim = input_dim[0] if isinstance(input_dim, (list, tuple)) else input_dim

        # Two conv stages reducing the token feature map 768 -> 256 channels.
        self.image_conv = nn.Sequential(
            nn.Conv2d(dim, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        # Projects the caption embedding to the same 256-dim space.
        self.caption_mlp = nn.Sequential(
            nn.Linear(caption_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # NOTE(review): initialized but never used in forward() — dead weight
        # in the state_dict; kept for checkpoint compatibility.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=256, num_heads=8, dropout=0.1, batch_first=True
        )

        # Fuses concatenated image+caption features (512 -> 128 channels).
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        # Final 1-channel prediction head.
        self.conv_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        self.out_channels = 1

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Fuse the first 12 (vision) features with the trailing caption embedding."""
        # features list contains: [vision_feat_0, ..., vision_feat_11, caption_embed]
        caption_embed = features[-1]  # [B, 384]
        image_features = features[:12]

        # Average vision tokens across layers.
        x = torch.stack(image_features, dim=1).mean(dim=1)  # [B, 196, 768]

        B, N, C = x.shape
        # Assumes a square token grid (N = H*W); 196 tokens -> 14x14.
        H = W = int(N ** 0.5)

        x = x.permute(0, 2, 1).view(B, C, H, W)  # [B, 768, 14, 14]
        img_features = self.image_conv(x)  # [B, 256, 14, 14]

        # Ensure caption embed has batch dim
        if caption_embed.dim() == 1:
            caption_embed = caption_embed.unsqueeze(0)

        caption_features = self.caption_mlp(caption_embed)  # [B, 256]

        # Broadcast the caption vector over the spatial grid.
        caption_spatial = caption_features.unsqueeze(-1).unsqueeze(-1)
        caption_spatial = caption_spatial.expand(B, -1, H, W)  # [B, 256, 14, 14]

        # Channel-concatenate and fuse.
        fused_features = torch.cat([img_features, caption_spatial], dim=1)  # [B, 512, 14, 14]
        fused = self.fusion_conv(fused_features)  # [B, 128, 14, 14]

        output = self.conv_head(fused)  # [B, 1, 14, 14]
        return output
208
+
209
+ # --- Helper Classes ---
210
+
211
class MetricTracker:
    """Accumulates per-batch losses and class predictions across an epoch,
    then computes aggregate binary-classification metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated targets, predictions and loss."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch: scalar loss plus argmax of one-hot targets/probs."""
        self.total_loss += loss
        self.steps += 1
        target_idx = torch.argmax(targets, dim=1).detach().cpu().numpy()
        pred_idx = torch.argmax(probabilities, dim=1).detach().cpu().numpy()
        self.all_targets.extend(target_idx)
        self.all_predictions.extend(pred_idx)

    def compute(self) -> Dict[str, float]:
        """Return aggregate metrics, or {} when nothing has been recorded."""
        if not self.all_targets:
            return {}

        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) != 0 else 0.0

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": specificity,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
        }
241
+
242
class MethaneTextTrainer:
    """Plain-PyTorch training loop around a TerraTorch ClassificationTask whose
    backbone additionally consumes per-sample text captions.

    Checkpoints (best by validation loss, plus the final epoch) and a
    per-epoch metrics CSV are written under ``save_dir/fold<test_fold>``.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        # Reuse the loss configured on the task so settings stay in one place.
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Build the ClassificationTask (text-aware backbone + SimpleDecoder)
        and return its underlying model moved to the target device."""
        model_args = dict(
            backbone="terramind_v1_base_with_text",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            num_classes=2,
            head_dropout=0.3,
            decoder="SimpleDecoder",
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """AdamW plus ReduceLROnPlateau stepped on validation loss."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one pass over *dataloader*; gradients flow only when stage == 'train'."""
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                # Prepare Inputs
                inputs = batch['S2L2A'].to(self.device)
                captions = batch['caption']  # List of strings
                targets = batch['label'].to(self.device)

                # Forward Pass (Note: passing captions explicitly)
                # The Task wrapper might expect x dict, but our custom backbone forward handles 'captions'
                outputs = self.model(x={"S2L2A": inputs}, captions=captions)
                probabilities = torch.softmax(outputs.output, dim=1)
                # NOTE(review): the CE criterion is applied to softmax outputs
                # (an effective double-softmax). Preserved as-is — confirm
                # whether raw logits were intended before changing it.
                loss = self.criterion(probabilities, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Append one row of metrics to train_val_metrics.csv, writing the
        header row when the file is first created.

        Bug fix: the existence check must run *before* opening in append mode,
        because ``open(..., 'a')`` creates the file, so the old post-open
        ``exists()`` check was always True and the header was never written.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]
        needs_header = not csv_path.exists()

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(headers)
            writer.writerow([epoch] + list(train_metrics.values()) + list(val_metrics.values()))

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop with LR scheduling, CSV logging and best/final
        checkpointing (best = lowest validation loss)."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            self.scheduler.step(val_metrics['Loss'])
            self.log_to_csv(epoch, train_metrics, val_metrics)

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info(f"--> New best model saved")

        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
346
+
347
+ # --- Data Utilities ---
348
+
349
def read_captions(json_path: Path, captions_dict: Dict) -> Dict:
    """Populate *captions_dict* with folder-name -> caption entries parsed
    from one caption JSON file, and return it.

    JSON keys are image file paths; the caption is the text following
    'CAPTION:' in the first nested string. The dict key is the image's
    parent directory name.
    """
    if not json_path.exists():
        logger.warning(f"Caption file not found: {json_path}")
        return captions_dict

    marker = "CAPTION:"
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        for file_path_str, text_list in data.items():
            if not (text_list and isinstance(text_list, list) and text_list[0]):
                continue
            text_content = text_list[0][0]
            caption_start = text_content.find(marker)
            if caption_start == -1:
                continue
            caption = text_content[caption_start + len(marker):].strip()
            # Key by the image's parent directory; normalise Windows
            # separators first so both path styles resolve identically.
            path_parts = file_path_str.replace("\\", "/").split("/")
            if len(path_parts) >= 2:
                captions_dict[path_parts[-2]] = caption
    except Exception as e:
        logger.error(f"Error reading captions {json_path}: {e}")

    return captions_dict
375
+
376
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return the 'Filename' values of summary rows whose 'Fold' is in *folds*."""
    summary = pd.read_excel(excel_file)
    return summary.loc[summary['Fold'].isin(folds), 'Filename'].tolist()
380
+
381
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Build train/val DataLoaders: load captions from both JSON sources,
    select the non-test folds, split 80/20, and derive a loader per split."""
    # 1. Both caption files feed one shared lookup table.
    captions_dict: Dict = {}
    for caption_file in (args.methane_captions, args.no_methane_captions):
        captions_dict = read_captions(Path(caption_file), captions_dict)
    logger.info(f"Loaded {len(captions_dict)} captions.")

    # 2. Every fold except the held-out test fold joins the training pool.
    train_pool_folds = [f for f in range(1, args.num_folds + 1) if f != args.test_fold]
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)

    # 3. Deterministic 80/20 split.
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
    logger.info(f"Train: {len(train_paths)}, Val: {len(val_paths)}")

    # 4. One datamodule is reused for both splits.
    datamodule = MethaneTextDataModule(
        data_root=args.root_dir,
        paths=paths,  # Initial dummy
        captions=captions_dict,
        train_transform=get_training_transforms(),
        batch_size=args.batch_size,
    )

    # Point the module at each split before calling setup().
    datamodule.paths = train_paths
    datamodule.setup(stage="train")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
417
+
418
+ # --- Main Execution ---
419
+
420
def parse_args():
    """Parse command-line arguments for the text-multimodal training run."""
    cli = argparse.ArgumentParser(description="Methane Text-Multimodal Training")

    # Data Paths
    cli.add_argument('--root_dir', type=str, required=True, help='Root directory for images')
    cli.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel')
    cli.add_argument('--methane_captions', type=str, required=True, help='Path to Methane JSON captions')
    cli.add_argument('--no_methane_captions', type=str, required=True, help='Path to No-Methane JSON captions')
    cli.add_argument('--save_dir', type=str, default='./checkpoints', help='Output directory')

    # Hyperparameters
    cli.add_argument('--epochs', type=int, default=100)
    cli.add_argument('--batch_size', type=int, default=4)
    cli.add_argument('--lr', type=float, default=5e-5)
    cli.add_argument('--weight_decay', type=float, default=0.05)
    cli.add_argument('--num_folds', type=int, default=5)
    cli.add_argument('--test_fold', type=int, default=2)
    cli.add_argument('--seed', type=int, default=42)

    return cli.parse_args()
440
+
441
if __name__ == "__main__":
    # Entry point: seed RNGs, build the loaders, then train.
    cli_args = parse_args()
    set_seed(cli_args.seed)

    loaders = get_data_loaders(cli_args)
    MethaneTextTrainer(cli_args).fit(*loaders)
intuition1_classification_finetuning/config/methane_simulated_datamodule.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import albumentations as A
3
+ from sklearn.model_selection import train_test_split
4
+ from torch.utils.data import DataLoader
5
+ from torchgeo.datamodules import NonGeoDataModule
6
+ from methane_simulated_dataset import MethaneSimulatedDataset
7
+
8
class MethaneSimulatedDataModule(NonGeoDataModule):
    """DataModule for simulated-methane classification: reads the fold summary
    spreadsheet, remaps filenames to the I1/TOA-reflectance naming scheme,
    and splits the non-test folds into train/val datasets."""

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,  # Default test fold from your script
        num_folds: int = 5,
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds

        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Augmentations applied to training samples only."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _get_simulated_paths(self, paths):
        """Rename summary filenames to the I1/TOA-reflectance convention.

        '<a>_<b>_<c>_<d>_<e>' becomes '<a>_toarefl_<d>_<e>'; names that do
        not fit the pattern are passed through unchanged.
        """
        renamed = []
        for original in paths:
            try:
                tokens = original.split('_')
                if len(tokens) >= 5:
                    renamed.append(f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}")
                else:
                    renamed.append(original)
            except Exception:
                renamed.append(original)
        return renamed

    def setup(self, stage: str = None):
        """Read the spreadsheet, drop the test fold, split, build datasets."""
        try:
            summary = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}")

        # Keep every fold except the held-out test fold.
        train_pool = [f for f in range(1, self.num_folds + 1) if f != self.test_fold]
        raw_paths = summary[summary['Fold'].isin(train_pool)]['Filename'].tolist()

        # Map summary filenames onto the on-disk naming scheme.
        paths = self._get_simulated_paths(raw_paths)

        # Deterministic train/val split.
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, drop_last=True)
intuition1_classification_finetuning/config/methane_simulated_dataset.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+
8
class MethaneSimulatedDataset(NonGeoDataset):
    """Dataset of simulated methane scenes: per-folder '<name>_hsi.dat' cubes
    with '<name>_mask.tif' masks, reduced to a scene-level binary label.

    NOTE(review): ``excel_file`` is accepted for call-site parity with the
    datamodule but is not used here.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None):
        """
        Args:
            root_dir: Directory containing one sub-folder per sample.
            excel_file: Unused (kept for interface compatibility).
            paths: Folder names (relative to root_dir) to include.
            transform: Optional albumentations transform applied to the cube.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect (mask, cube) pairs; folders with missing files are skipped.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                # Note: Filenames here seem to match the folder name based on your script
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        # Only folders where both files existed at init time are counted.
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Return {'image': [12, 224, 224] tensor, 'label': 0/1 index, 'sample': cube path}."""
        label_path, scube_path = self.data_paths[idx]

        # Load label
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load sCube (I1/TOA data)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()
            # Read only first 12 bands
            scube_image = scube_image[:12, :, :]

        # Convert to Tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize: bilinear for imagery, nearest for the mask (keeps it binary).
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Scene-level binary index: any positive mask pixel -> class 1.
        contains_methane = (label_tensor > 0).any().long()

        # Apply transformations
        if self.transform:
            # Albumentations expects [H, W, C]
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # NOTE(review): this leaves a numpy array (not a torch.Tensor) in
            # 'image' for transformed samples — confirm collation handles it.
            scube_tensor = transformed['image'].transpose(2, 0, 1)

        return {
            'image': scube_tensor,  # <--- Named 'image' for TorchGeo
            'label': contains_methane,  # <--- Index for CrossEntropy
            'sample': scube_path
        }
intuition1_classification_finetuning/config/train.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ trainer:
4
+ accelerator: auto
5
+ strategy: auto
6
+ devices: 1
7
+ max_epochs: 100
8
+ default_root_dir: ./checkpoints_i1
9
+ callbacks:
10
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
11
+ init_args:
12
+ monitor: val/loss
13
+ mode: min
14
+ save_top_k: 1
15
+ filename: "best_model"
16
+ save_last: true
17
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
18
+ init_args:
19
+ logging_interval: epoch
20
+
21
+ model:
22
+ class_path: terratorch.tasks.ClassificationTask
23
+ init_args:
24
+ model_factory: EncoderDecoderFactory
25
+ loss: ce
26
+ ignore_index: -1
27
+ lr: 1.0e-5 # Top-level LR
28
+
29
+ optimizer: AdamW
30
+ optimizer_hparams:
31
+ weight_decay: 0.05
32
+
33
+ scheduler: ReduceLROnPlateau
34
+ scheduler_hparams:
35
+ mode: min
36
+ patience: 5
37
+
38
+ model_args:
39
+ backbone: terramind_v1_base
40
+ backbone_pretrained: true
41
+ backbone_modalities:
42
+ - S2L2A
43
+ backbone_merge_method: mean
44
+
45
+ decoder: UperNetDecoder
46
+ decoder_scale_modules: true
47
+ decoder_channels: 256
48
+ num_classes: 2
49
+ head_dropout: 0.3
50
+
51
      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]
        # NOTE(review): the companion training script additionally appends a
        # LearnedInterpolateToPyramidal neck after SelectIndices; confirm
        # whether its omission here is intentional so YAML and script runs
        # stay equivalent.
56
+
57
+ data:
58
+ class_path: methane_simulated_datamodule.MethaneSimulatedDataModule
59
+ init_args:
60
+ data_root: # Place the data root here
61
+ excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx # Update this!
62
+ batch_size: 8
63
+ val_split: 0.2
64
+ seed: 42
65
+ test_fold: 4
66
+ num_folds: 5
intuition1_classification_finetuning/script/methane_simulated_datamodule.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+ from methane_simulated_dataset import MethaneSimulatedDataset
8
+ # from methane_classification_dataset import MethaneClassificationDataset
9
+
10
class MethaneSimulatedDataModule(NonGeoDataModule):
    """DataModule wrapping :class:`MethaneSimulatedDataset`.

    The same explicit list of folder names (``paths``) is used for every
    split; callers swap ``self.paths`` and re-run ``setup`` per stage
    (see the training script).
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def _make_dataset(self, transform):
        # All splits share the same construction; only the transform differs.
        return MethaneSimulatedDataset(
            root_dir=self.data_root,
            excel_file=self.excel_file,
            paths=self.paths,
            transform=transform,
        )

    def setup(self, stage: str = None):
        """Instantiate the dataset(s) needed for *stage*."""
        if stage in ("fit", "train"):
            self.train_dataset = self._make_dataset(self.train_transform)
        if stage in ("fit", "validate", "val"):
            self.val_dataset = self._make_dataset(self.val_transform)
        if stage in ("test", "predict"):
            self.test_dataset = self._make_dataset(self.test_transform)

    def train_dataloader(self):
        # drop_last=True keeps batch statistics stable during optimisation.
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        # Fix: do NOT drop the last partial batch during evaluation —
        # drop_last=True silently excluded up to batch_size-1 samples from
        # the metrics, and yielded zero batches when the validation set was
        # smaller than one batch.
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, drop_last=False
        )

    def test_dataloader(self):
        # Same fix as val_dataloader: evaluate on every sample.
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers, drop_last=False
        )
intuition1_classification_finetuning/script/methane_simulated_dataset.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
def min_max_normalize(data, new_min=0, new_max=1):
    """Linearly rescale *data* into ``[new_min, new_max]`` as float32.

    Non-finite entries (NaN / +inf / -inf) are first replaced with finite
    statistics of the array.  The previous implementation computed the
    replacement values with plain ``np.min`` / ``np.max``, which themselves
    return NaN or inf on arrays containing non-finite entries, so NaN/inf
    could survive the cleaning step and poison the normalization.

    Constant (or all-non-finite) input maps to a uniform array of
    ``new_min``, preventing division by zero.
    """
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Derive replacement values from the *finite* entries only.
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        # Nothing usable to derive a range from.
        return np.full_like(data, new_min, dtype=np.float32)

    lo = float(finite.min())
    hi = float(finite.max())
    # NaN -> smallest finite value (mirrors the original nanmin intent),
    # +inf -> largest finite value, -inf -> smallest finite value.
    data = np.nan_to_num(data, nan=lo, posinf=hi, neginf=lo)

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
22
+
23
class MethaneSimulatedDataset(NonGeoDataset):
    """Simulated I1/TOA methane patches for binary patch classification.

    Each folder ``<root_dir>/<name>/`` must contain ``<name>_mask.tif``
    (label mask) and ``<name>_hsi.dat`` (data cube); pairs with a missing
    file are silently skipped.  Samples are returned under the ``'S2L2A'``
    key with a one-hot label (unlike the sibling Sentinel-2 dataset which
    uses ``'image'`` plus a class index).
    """

    def __init__(self, root_dir, excel_file, paths, transform=None):
        # NOTE(review): `excel_file` is accepted but never stored or used —
        # fold filtering happens in the caller; confirm it can be dropped.
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []  # list of (label_path, scube_path) tuples

        # Collect existing (mask, cube) pairs in the selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Number of valid (mask, cube) pairs found at construction time."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and binarize one sample.

        Returns:
            dict with 'S2L2A' (float tensor [12, 224, 224]), 'label'
            (one-hot float tensor [2]) and 'sample' (cube file path).
        """
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (rasterio returns [bands, H, W])
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()  # presumably [1, 512, 512] — TODO confirm


        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()

        # Read only the first 12 bands for testing purposes
        # Map the bands later on
        scube_image = scube_image[:12, :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize: bilinear for reflectance, nearest for the mask (keeps the
        # label values from being interpolated into non-binary values)
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0
        # normalized_tensor = min_max_normalize(scube_tensor)
        # Patch-level label: 1 if any mask pixel is positive
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            # Albumentations expects [H, W, C]
            # NOTE(review): `.transpose(2, 0, 1)` on the albumentations output
            # yields a numpy array, not a tensor — confirm the collate fn copes.
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)  # Convert back to [C, H, W]

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'sample': scube_path}
intuition1_classification_finetuning/script/train_simulated_I1.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Dict, List, Tuple, Any, Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import albumentations as A
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.metrics import (
20
+ accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
21
+ )
22
+ from rasterio.errors import NotGeoreferencedWarning
23
+
24
+ # --- CRITICAL IMPORTS ---
25
+ import terramind
26
+ from terratorch.tasks import ClassificationTask
27
+
28
+ # Local Imports
29
+ from methane_simulated_datamodule import MethaneSimulatedDataModule
30
+
31
+ # --- Configuration & Setup ---
32
+
33
+ logging.basicConfig(
34
+ level=logging.INFO,
35
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
36
+ datefmt='%Y-%m-%d %H:%M:%S'
37
+ )
38
+ logger = logging.getLogger(__name__)
39
+
40
+ logging.getLogger("rasterio._env").setLevel(logging.ERROR)
41
+ warnings.simplefilter("ignore", NotGeoreferencedWarning)
42
+ warnings.filterwarnings("ignore", category=FutureWarning)
43
+
44
def set_seed(seed: int = 42):
    """Seed every RNG in play (python, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
50
+
51
def get_training_transforms() -> A.Compose:
    """Albumentations augmentation pipeline for the training split.

    Geometric transforms only (elastic warp, rotations, flips, small
    shift/scale), so band reflectance values are not altered.
    """
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])
58
+
59
+ # --- Path Utilities ---
60
+
61
def get_simulated_paths(paths: List[str]) -> List[str]:
    """
    Modifies filenames to match the I1/TOA naming convention.
    Converts 'ang2015..._S2_...' -> 'ang2015..._toarefl_...'

    Names with fewer than five '_'-separated tokens (or that fail to
    parse) are passed through unchanged.
    """
    converted = []
    for original in paths:
        try:
            parts = original.split('_')
            if len(parts) < 5:
                converted.append(original)
            else:
                # Target layout: {ID}_toarefl_{Coord1}_{Coord2}
                converted.append('_'.join((parts[0], 'toarefl', parts[3], parts[4])))
        except Exception as e:
            logger.warning(f"Could not parse path {original}: {e}")
            converted.append(original)
    return converted
81
+
82
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return the 'Filename' entries whose 'Fold' column value is in *folds*.

    Any read/parse failure is logged and re-raised.
    """
    try:
        sheet = pd.read_excel(excel_file)
        selected = sheet[sheet['Fold'].isin(folds)]
        return selected['Filename'].tolist()
    except Exception as e:
        logger.error(f"Error reading Excel file: {e}")
        raise
90
+
91
+ # --- Helper Classes ---
92
+
93
class MetricTracker:
    """Accumulates per-batch predictions/targets over an epoch and derives
    binary classification metrics (positive class = 1)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state (call at the start of each epoch)."""
        # Flat python lists of per-sample class indices.
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch.

        `targets` is expected to be one-hot [B, 2] (argmax recovers the
        class index — TODO confirm against the dataset's label format);
        `probabilities` is the [B, 2] softmax output.
        """
        self.total_loss += loss
        self.steps += 1
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        """Return epoch-level metrics; empty dict if nothing was recorded."""
        if not self.all_targets:
            return {}

        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent.
        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
            "TP": int(tp), "TN": int(tn), "FP": int(fp), "FN": int(fn)
        }
124
+
125
class TrainerI1:
    """Plain-PyTorch training driver for the I1 (TOA reflectance) classifier.

    Builds a TerraTorch ClassificationTask (TerraMind backbone + UperNet
    decoder), then runs a manual train/validate loop with CSV metric
    logging and best-checkpoint saving (keyed on validation loss).
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Checkpoints and metrics are grouped per held-out fold.
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # NOTE: _init_model() also assigns self.task as a side effect;
        # self.criterion below depends on that ordering.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Instantiate the ClassificationTask and return its inner model on-device."""
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """AdamW + ReduceLROnPlateau stepped on validation loss.

        NOTE(review): the `verbose` kwarg of ReduceLROnPlateau is deprecated
        in recent PyTorch — confirm the installed version still accepts it.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one pass over *dataloader*; backprop only when stage == 'train'."""
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)
                # NOTE(review): the loss is computed on *softmax probabilities*;
                # cross-entropy criteria normally expect raw logits, so this
                # effectively applies softmax twice — confirm intended.
                loss = self.criterion(probabilities, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop: per-epoch train + validate, CSV log, checkpoints."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        # Initialize CSV logging (header written once, rows appended per epoch)
        csv_path = self.save_dir / 'train_val_metrics.csv'
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'Epoch', 'Train_Loss', 'Train_F1', 'Train_Acc',
                'Val_Loss', 'Val_F1', 'Val_Acc', 'Val_Spec', 'Val_Sens'
            ])

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # LR scheduling is driven by the validation loss.
            self.scheduler.step(val_metrics['Loss'])

            # Log to CSV
            with open(csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    train_metrics.get('Loss'), train_metrics.get('F1'), train_metrics.get('Accuracy'),
                    val_metrics.get('Loss'), val_metrics.get('F1'), val_metrics.get('Accuracy'),
                    val_metrics.get('Specificity'), val_metrics.get('Sensitivity')
                ])

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            # Save Best Model (lowest validation loss so far)
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info(f"--> New best model saved")

        # Save Final Model (last epoch, regardless of validation loss)
        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
242
+
243
+ # --- Data Utilities ---
244
+
245
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Build train/val dataloaders from all folds except the held-out one.

    The held-out fold (`args.test_fold`) is excluded entirely; the
    remaining pool is split 80/20 into train/val with a fixed seed.
    """
    # 1. Determine Folds
    all_folds = list(range(1, args.num_folds + 1))
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # 2. Get Paths & Convert to TOA/I1 format
    # Note: Using get_simulated_paths to transform names as done in the notebook
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)
    paths = get_simulated_paths(paths)

    # 3. Train/Val Split (80/20)
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # 4. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=paths,  # Initial dummy; replaced per split below
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # 5. Create Loaders
    # NOTE(review): a single datamodule instance is re-pointed at each split
    # before setup(); train_loader keeps referencing the train_dataset built
    # in the first setup() call, so this works — but it is fragile. Two
    # datamodule instances would be clearer.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
280
+
281
+ # --- Main Execution ---
282
+
283
def parse_args():
    """Build and parse the CLI arguments for the I1 (TOA refl) training run."""
    parser = argparse.ArgumentParser(description="Methane I1 (TOA Refl) Training")

    # Paths
    path_specs = [
        ('--root_dir', dict(type=str, required=True, help='Root directory for I1/TOA data')),
        ('--excel_file', dict(type=str, required=True, help='Path to Summary Excel')),
        ('--save_dir', dict(type=str, default='./checkpoints_i1', help='Output directory')),
    ]
    # Hyperparameters
    hparam_specs = [
        ('--epochs', dict(type=int, default=100)),
        ('--batch_size', dict(type=int, default=1)),
        ('--lr', dict(type=float, default=1e-5)),
        ('--weight_decay', dict(type=float, default=0.05)),
        ('--num_folds', dict(type=int, default=5)),
        ('--test_fold', dict(type=int, default=4, help='Fold ID to hold out')),
        ('--seed', dict(type=int, default=42)),
    ]
    for flag, spec in path_specs + hparam_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
301
+
302
if __name__ == "__main__":
    # Entry point: seed all RNGs, build loaders for the selected folds,
    # then run the manual training loop.
    args = parse_args()
    set_seed(args.seed)

    train_loader, val_loader = get_data_loaders(args)

    trainer = TrainerI1(args)
    trainer.fit(train_loader, val_loader)
sentinel2_classification_finetuning/config/methane_simulated_datamodule.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import albumentations as A
3
+ from sklearn.model_selection import train_test_split
4
+ from torch.utils.data import DataLoader
5
+ from torchgeo.datamodules import NonGeoDataModule
6
+ from methane_simulated_dataset import MethaneSimulatedDataset
7
+
8
class MethaneSimulatedDataModule(NonGeoDataModule):
    """Self-contained datamodule for the simulated Sentinel-2 methane data.

    Unlike the script variant, this class does its own fold filtering,
    filename rewriting (via ``sim_tag``) and train/val splitting inside
    ``setup``, so it can be driven directly from a Lightning CLI YAML.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,
        num_folds: int = 5,
        sim_tag: str = "toarefl",  # <--- New arg for 'toarefl'/'boarefl'
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds
        self.sim_tag = sim_tag

        # Populated in setup().
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Geometric-only augmentation pipeline used for the training split."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _get_simulated_paths(self, paths):
        """Internal logic to rename files based on sim_tag"""
        simulated_paths = []
        for path in paths:
            try:
                tokens = path.split('_')
                if len(tokens) >= 5:
                    # Logic: {ID}_{tag}_{Coord1}_{Coord2}
                    simulated_path = f"{tokens[0]}_{self.sim_tag}_{tokens[3]}_{tokens[4]}"
                    simulated_paths.append(simulated_path)
                else:
                    # Too few tokens to rewrite — keep the name as-is.
                    simulated_paths.append(path)
            except Exception:
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        """Read the Excel summary, filter folds, split, and build datasets."""
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}")

        # 2. Filter Folds (Exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]

        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply Path Renaming Logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/Val Split (seeded, so every stage sees the same split)
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 5. Instantiate Datasets
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        # NOTE(review): drop_last=True on the *validation* loader silently
        # discards up to batch_size-1 samples from every evaluation, and
        # yields no batches at all if the val set is smaller than one
        # batch — confirm this is intentional.
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, drop_last=True)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Adapt TorchGeo batches to the TerraMind multi-modal input format."""
        # 1. Run TorchGeo default (expects 'image')
        batch = super().on_after_batch_transfer(batch, dataloader_idx)

        # 2. Wrap into TerraMind format {'S2L2A': ...}
        if 'image' in batch:
            s2_data = batch['image']
            batch['image'] = {'S2L2A': s2_data}

        return batch
sentinel2_classification_finetuning/config/methane_simulated_dataset.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+
8
class MethaneSimulatedDataset(NonGeoDataset):
    """Simulated Sentinel-2 methane patches (TorchGeo NonGeoDataset).

    Each folder ``<root_dir>/<name>/`` is expected to contain
    ``<name>_mask.tif`` (label) and ``<name>_hsi.dat`` (ENVI data cube);
    folders missing either file are silently skipped.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None):
        # NOTE(review): `excel_file` is accepted but never used here — fold
        # filtering happens in the datamodule; confirm it can be dropped.
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []  # list of (label_path, scube_path) tuples

        # Collect existing (mask, cube) pairs
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                # Construct paths based on folder name
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        """Number of valid (mask, cube) pairs found at construction time."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load, resize and binarize one sample.

        Returns:
            dict with 'image' (float tensor [12, 224, 224]), 'label'
            (scalar LongTensor, 1 if any mask pixel is positive) and
            'sample' (cube file path).
        """
        label_path, scube_path = self.data_paths[idx]

        # Load label
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load sCube (Try explicit ENVI driver first for .dat files)
        try:
            with rasterio.open(scube_path, driver='ENVI') as scube_src:
                scube_image = scube_src.read()
                scube_image = scube_image[:12, :, :]  # Read first 12 bands
        except Exception:
            # Fallback if driver auto-detection is needed
            with rasterio.open(scube_path) as scube_src:
                scube_image = scube_src.read()
                scube_image = scube_image[:12, :, :]

        # Convert to Tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize: bilinear for reflectance, nearest for the mask (keeps the
        # label values from being interpolated into non-binary values)
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Convert labels to binary index (0 or 1)
        contains_methane = (label_tensor > 0).any().long()

        # Apply transformations
        if self.transform:
            # Albumentations expects [H, W, C]
            # NOTE(review): `.transpose(2, 0, 1)` on the albumentations output
            # yields a numpy array, not a tensor — confirm the collate fn copes.
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)

        return {
            'image': scube_tensor,      # <--- 'image' for TorchGeo
            'label': contains_methane,  # <--- Index for CE Loss
            'sample': scube_path
        }
sentinel2_classification_finetuning/config/train.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ trainer:
4
+ accelerator: auto
5
+ strategy: auto
6
+ devices: 1
7
+ max_epochs: 100
8
+ default_root_dir: ./checkpoints_s2_simulated
9
+ callbacks:
10
+ - class_path: lightning.pytorch.callbacks.ModelCheckpoint
11
+ init_args:
12
+ monitor: val/loss
13
+ mode: min
14
+ save_top_k: 1
15
+ filename: "best_model"
16
+ save_last: true
17
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor
18
+ init_args:
19
+ logging_interval: epoch
20
+
21
+ model:
22
+ class_path: terratorch.tasks.ClassificationTask
23
+ init_args:
24
+ model_factory: EncoderDecoderFactory
25
+ loss: ce
26
+ ignore_index: -1
27
+ lr: 1.0e-5 # Top-level LR
28
+
29
+ optimizer: AdamW
30
+ optimizer_hparams:
31
+ weight_decay: 0.05
32
+
33
+ scheduler: ReduceLROnPlateau
34
+ scheduler_hparams:
35
+ mode: min
36
+ patience: 5
37
+
38
+ model_args:
39
+ backbone: terramind_v1_base
40
+ backbone_pretrained: true
41
+ backbone_modalities:
42
+ - S2L2A
43
+ backbone_merge_method: mean
44
+
45
+ decoder: UperNetDecoder
46
+ decoder_scale_modules: true
47
+ decoder_channels: 256
48
+ num_classes: 2
49
+ head_dropout: 0.3
50
+
51
      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]
        # NOTE(review): the companion training/inference scripts also append
        # a LearnedInterpolateToPyramidal neck after SelectIndices; confirm
        # whether its omission here is intentional so YAML and script runs
        # stay equivalent.
56
+
57
+ data:
58
+ class_path: methane_simulated_datamodule.MethaneSimulatedDataModule
59
+ init_args:
60
+ data_root: /path/to/data_root # UPDATE THIS
61
+ excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
62
+ batch_size: 8
63
+ val_split: 0.2
64
+ seed: 42
65
+ test_fold: 4
66
+ num_folds: 5
67
+ sim_tag: toarefl # Change to 'boarefl' if needed
sentinel2_classification_finetuning/script/inference_s2_simulated.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Dict, List, Tuple, Any, Optional
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import torch.nn as nn
15
+ import albumentations as A
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+ from rasterio.errors import NotGeoreferencedWarning
19
+
20
+ # --- CRITICAL IMPORTS ---
21
+ import terramind
22
+ from terratorch.tasks import ClassificationTask
23
+
24
+ # Local Imports
25
+ from methane_simulated_datamodule import MethaneSimulatedDataModule
26
+
27
+ # --- Configuration & Setup ---
28
+
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
32
+ datefmt='%Y-%m-%d %H:%M:%S'
33
+ )
34
+ logger = logging.getLogger(__name__)
35
+
36
+ logging.getLogger("rasterio._env").setLevel(logging.ERROR)
37
+ warnings.simplefilter("ignore", NotGeoreferencedWarning)
38
+ warnings.filterwarnings("ignore", category=FutureWarning)
39
+
40
def set_seed(seed: int = 42):
    """Seed python, numpy, torch (and CUDA when present) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
46
+
47
def get_inference_transforms() -> Optional["A.Compose"]:
    """Return the transform pipeline used at inference time.

    No augmentation is applied at inference, so this always returns
    ``None``.  Fix: the previous annotation promised ``A.Compose`` while
    the body unconditionally returned ``None``.
    """
    return None
49
+
50
+ # --- Path Utilities (Crucial for Simulated Data) ---
51
+
52
+ def get_simulated_paths(paths: List[str]) -> List[str]:
53
+ """
54
+ Modifies filenames to match the simulated dataset naming convention.
55
+ Original: 'MBD_0001_S2_...' -> Simulated: 'MBD_toarefl_S2_...'
56
+ """
57
+ simulated_paths = []
58
+ for path in paths:
59
+ try:
60
+ tokens = path.split('_')
61
+ # Reconstruct filename based on notebook logic
62
+ if len(tokens) >= 5:
63
+ # e.g., MBD_toarefl_S2_123_456
64
+ simulated_path = f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
65
+ simulated_paths.append(simulated_path)
66
+ else:
67
+ simulated_paths.append(path)
68
+ except Exception as e:
69
+ logger.warning(f"Could not parse path {path}: {e}")
70
+ simulated_paths.append(path)
71
+ return simulated_paths
72
+
73
+ # --- Inference Class ---
74
+
75
class SimulatedInference:
    """Inference driver for the simulated Sentinel-2 methane classifier.

    Rebuilds the training-time architecture, loads checkpoint weights,
    predicts a binary label per sample, and writes results to CSV.
    """

    def __init__(self, args: argparse.Namespace):
        # args must carry: checkpoint, output_dir (and device is auto-picked).
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initializing Inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Recreate the classifier architecture used in training.

        backbone_pretrained=False because all weights come from the
        checkpoint loaded afterwards.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )

        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        """Load weights from a .pth / Lightning checkpoint and switch to eval.

        Raises:
            FileNotFoundError: if `checkpoint_path` does not exist.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)

        # Lightning checkpoints nest weights under 'state_dict'; a raw
        # state-dict file is used as-is.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # NOTE(review): strict=False silently ignores missing/unexpected keys,
        # so an incompatible checkpoint would load without error — confirm.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates predictions and matches them with provided sample identifiers.

        Pairing relies on `dataloader` yielding samples in exactly the same
        order as `sample_names`.
        """
        results = {}

        logger.info(f"Starting inference on {len(sample_names)} samples...")

        # Iterator to match predictions with original filenames
        name_iter = iter(sample_names)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)

                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()

                # Assign to Sample Names
                for pred in batch_preds:
                    try:
                        sample_name = next(name_iter)
                        results[sample_name] = int(pred)
                    except StopIteration:
                        logger.error("More predictions than sample names! Check sync.")
                        break

        # NOTE(review): the datamodule's test loader is built with
        # drop_last=True, so a dataset size not divisible by batch_size drops
        # trailing samples and triggers this warning — confirm intended.
        if len(results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(results)}.")

        # Save CSV
        self._save_results(results)

    def _save_results(self, results: Dict[str, int]):
        """Write {sample -> prediction} rows to simulated_predictions.csv."""
        csv_path = self.output_dir / "simulated_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_ID', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")
176
+
177
+ # --- Data Loading ---
178
+
179
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the inference dataloader plus the ordered list of simulated
    sample names that run_inference() pairs with predictions one-to-one."""
    # 1. Read Excel to get base filenames
    try:
        df = pd.read_excel(args.excel_file)
        # If specific folds are requested, filter them
        if args.folds:
            folds_to_use = [int(f) for f in args.folds.split(',')]
            df = df[df['Fold'].isin(folds_to_use)]
            logger.info(f"Filtered to folds: {folds_to_use}")

        raw_paths = df['Filename'].tolist()
        logger.info(f"Loaded {len(raw_paths)} paths from Excel.")
    except Exception as e:
        logger.error(f"Error reading Excel: {e}")
        raise

    # 2. Transform paths to Simulated format
    simulated_paths = get_simulated_paths(raw_paths)

    # 3. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=simulated_paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
    )

    # 4. Setup
    datamodule.paths = simulated_paths
    datamodule.setup(stage="test")

    # Try getting test_dataloader, else train/val
    # NOTE(review): hasattr() is always true for this datamodule, and its
    # test loader uses drop_last=True, which can drop trailing samples when
    # batch_size > 1 and desync predictions from `simulated_paths` — confirm.
    loader = datamodule.test_dataloader() if hasattr(datamodule, 'test_dataloader') else datamodule.train_dataloader()

    return loader, simulated_paths
216
+
217
+ # --- Main Execution ---
218
+
219
def parse_args():
    """Build and parse the CLI options for simulated Sentinel-2 inference."""
    parser = argparse.ArgumentParser(description="Methane Simulated S2 Inference")
    add = parser.add_argument

    # Required inputs
    add('--root_dir', type=str, required=True, help='Root directory for simulated data')
    add('--excel_file', type=str, required=True, help='Path to Summary Excel file')
    add('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')

    # Optional behavior
    add('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    add('--folds', type=str, default=None, help='Comma-separated list of folds to infer on (e.g., "4" or "1,2"). If None, uses all.')
    add('--batch_size', type=int, default=1, help='Inference batch size')
    add('--seed', type=int, default=42)

    return parser.parse_args()
231
+
232
if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)  # seed all RNGs before any data handling

    # 1. Prepare Data & Names (loader order must match sample_names order)
    dataloader, sample_names = get_dataloader_and_names(args)

    # 2. Run Inference
    engine = SimulatedInference(args)
    engine.run_inference(dataloader, sample_names)
sentinel2_classification_finetuning/script/methane_simulated_datamodule.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+ from methane_simulated_dataset import MethaneSimulatedDataset
8
+ # from methane_classification_dataset import MethaneClassificationDataset
9
+
10
class MethaneSimulatedDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneSimulatedDataset.

    The same `paths` list feeds all three splits; callers reassign
    `self.paths` between setup() calls to materialize different splits.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        """
        Args:
            data_root: Root directory containing one folder per sample.
            excel_file: Path to the summary spreadsheet (forwarded to the dataset).
            paths: Folder names included in the split(s) built by setup().
            batch_size: Samples per batch for every loader.
            num_workers: DataLoader worker processes.
            train_transform/val_transform/test_transform: Optional
                albumentations-style callables, one per split.
        """
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        """Instantiate the dataset(s) required for `stage`."""
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        # drop_last=True keeps training batch sizes uniform (e.g. for BatchNorm).
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        # FIX: drop_last=False — dropping the trailing partial batch silently
        # excluded samples from validation metrics.
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )

    def test_dataloader(self):
        # FIX: drop_last=False — inference pairs predictions with a full list
        # of sample names, so every sample must be yielded.
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )
sentinel2_classification_finetuning/script/methane_simulated_dataset.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
def min_max_normalize(data, new_min=0, new_max=1):
    """Rescale `data` linearly into [new_min, new_max].

    NaN values are replaced with the finite minimum, +/-inf with the finite
    maximum/minimum. A constant (or entirely non-finite) input returns a
    uniform array of `new_min`.

    FIX: the previous version passed posinf=np.max(data) / neginf=np.min(data)
    to nan_to_num — a no-op when the data actually contains +/-inf, since the
    max/min are then infinite themselves.
    """
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Derive replacement values from the finite entries only.
    finite = data[np.isfinite(data)]
    if finite.size == 0:  # All NaN/inf: nothing meaningful to scale
        return np.full_like(data, new_min, dtype=np.float32)

    old_min, old_max = float(finite.min()), float(finite.max())
    data = np.nan_to_num(data, nan=old_min, posinf=old_max, neginf=old_min)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
22
+
23
class MethaneSimulatedDataset(NonGeoDataset):
    """Image-level methane-presence dataset over simulated Sentinel-2 cubes.

    Each sample folder under `root_dir` must contain '<name>_mask.tif'
    (label raster) and '<name>_hsi.dat' (multi-band cube). Samples are
    returned as 224x224 tensors with a one-hot image-level label.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None):
        """
        Args:
            root_dir: Directory containing one folder per sample.
            excel_file: Unused here; kept for interface parity with the
                other methane datasets.
            paths: Folder names to look up under `root_dir`.
            transform: Optional albumentations-style callable applied HWC.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Keep only folders that contain both the mask and the cube.
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        # Number of samples actually found on disk.
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label raster (single band, shape [1, H, W]).
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load the multi-band cube.
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()

        # Keep only the first 12 bands to match the S2L2A backbone input.
        # NOTE(review): band-to-S2 mapping is assumed positional — confirm.
        scube_image = scube_image[:12, :, :]

        scube_tensor = torch.from_numpy(scube_image).float()  # [12, H, W]
        label_tensor = torch.from_numpy(label_image).float()  # [1, H, W]

        # Resize to the 224x224 model input; 'nearest' keeps labels discrete.
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Image-level label: any positive pixel -> class 1.
        contains_methane = (label_tensor > 0).any().long()
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # FIX: convert the augmented numpy array back to a torch tensor so
            # the sample type is consistent whether or not a transform ran.
            scube_tensor = torch.from_numpy(transformed['image'].transpose(2, 0, 1)).float()

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'sample': scube_path}
sentinel2_classification_finetuning/script/train_simulated_s2.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Dict, List, Tuple, Any, Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import albumentations as A
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.metrics import (
20
+ accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
21
+ )
22
+ from rasterio.errors import NotGeoreferencedWarning
23
+
24
+ # --- CRITICAL IMPORTS ---
25
+ import terramind
26
+ from terratorch.tasks import ClassificationTask
27
+
28
+ # Local Imports
29
+ from methane_simulated_datamodule import MethaneSimulatedDataModule
30
+
31
+ # --- Configuration & Setup ---
32
+
33
+ logging.basicConfig(
34
+ level=logging.INFO,
35
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
36
+ datefmt='%Y-%m-%d %H:%M:%S'
37
+ )
38
+ logger = logging.getLogger(__name__)
39
+
40
+ logging.getLogger("rasterio._env").setLevel(logging.ERROR)
41
+ warnings.simplefilter("ignore", NotGeoreferencedWarning)
42
+ warnings.filterwarnings("ignore", category=FutureWarning)
43
+
44
def set_seed(seed: int = 42):
    """Make python/numpy/torch (and CUDA, when present) runs reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
50
+
51
def get_training_transforms() -> A.Compose:
    """Spatial augmentation pipeline for training cubes.

    All transforms are geometry-only (elastic warp, 90° rotations, flips,
    shift/scale/rotate), so per-band radiometry is preserved.
    NOTE(review): A.Flip and A.ShiftScaleRotate are deprecated in newer
    albumentations releases — confirm the pinned version supports them.
    """
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])
58
+
59
+ # --- Path Utilities ---
60
+
61
def get_simulated_paths(paths: List[str], tag: str = "toarefl") -> List[str]:
    """
    Modifies filenames to match the I1/TOA naming convention.
    Converts 'ang2015..._S2_...' -> 'ang2015..._{tag}_...'
    """
    renamed = []
    for original in paths:
        try:
            parts = original.split('_')
            if len(parts) < 5:
                renamed.append(original)  # too short to rewrite
            else:
                # Target format: {ID}_{tag}_{Coord1}_{Coord2}
                renamed.append('_'.join((parts[0], tag, parts[3], parts[4])))
        except Exception as e:
            logger.warning(f"Could not parse path {original}: {e}")
            renamed.append(original)
    return renamed
81
+
82
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return 'Filename' entries from the summary sheet whose 'Fold' is in `folds`."""
    try:
        summary = pd.read_excel(excel_file)
        selected = summary.loc[summary['Fold'].isin(folds), 'Filename']
        return selected.tolist()
    except Exception as e:
        logger.error(f"Error reading Excel file: {e}")
        raise
90
+
91
+ # --- Helper Classes ---
92
+
93
class MetricTracker:
    """Accumulates per-batch targets/predictions and computes epoch-level
    binary-classification metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state for a fresh epoch."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch: scalar loss plus argmax'd labels and predictions."""
        self.total_loss += loss
        self.steps += 1
        batch_targets = torch.argmax(targets, dim=1).detach().cpu().numpy()
        batch_preds = torch.argmax(probabilities, dim=1).detach().cpu().numpy()
        self.all_targets.extend(batch_targets)
        self.all_predictions.extend(batch_preds)

    def compute(self) -> Dict[str, float]:
        """Return aggregate metrics for the epoch (empty dict if no data seen)."""
        if not self.all_targets:
            return {}

        y_true, y_pred = self.all_targets, self.all_predictions
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

        metrics = {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(y_true, y_pred),
            # Specificity = TN rate; guarded against an all-positive epoch.
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(y_true, y_pred),
        }
        metrics.update({"TP": int(tp), "TN": int(tn), "FP": int(fp), "FN": int(fn)})
        return metrics
124
+
125
class TrainerI1:
    """Fine-tuning loop for the TerraMind classifier on simulated data.

    Builds the model via terratorch's ClassificationTask, trains with AdamW
    plus a plateau LR scheduler, logs per-epoch metrics to CSV, and keeps
    the checkpoint with the lowest validation loss.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # One output sub-directory per held-out fold.
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # _init_model() also sets self.task, used for the criterion below.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Build the TerraMind backbone + UperNet decoder classifier and
        move it to the training device."""
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """AdamW plus ReduceLROnPlateau stepped on validation loss.

        NOTE(review): `verbose=` is deprecated/removed in recent torch
        releases — confirm the pinned torch version still accepts it.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one epoch over `dataloader`; gradients only when stage == 'train'."""
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)
                # NOTE(review): the "ce" criterion (CrossEntropyLoss) applies
                # log-softmax internally, so feeding it softmaxed probabilities
                # double-squashes the logits — confirm this is intentional.
                loss = self.criterion(probabilities, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop: per-epoch train/validate, CSV logging,
        LR scheduling, and best/final checkpointing."""
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        # Initialize CSV logging (header written once, rows appended per epoch)
        csv_path = self.save_dir / 'train_val_metrics.csv'
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'Epoch', 'Train_Loss', 'Train_F1', 'Train_Acc',
                'Val_Loss', 'Val_F1', 'Val_Acc', 'Val_Spec', 'Val_Sens'
            ])

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # Plateau scheduler keys off the validation loss.
            self.scheduler.step(val_metrics['Loss'])

            # Log to CSV (re-opened in append mode so rows survive a crash)
            with open(csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    train_metrics.get('Loss'), train_metrics.get('F1'), train_metrics.get('Accuracy'),
                    val_metrics.get('Loss'), val_metrics.get('F1'), val_metrics.get('Accuracy'),
                    val_metrics.get('Specificity'), val_metrics.get('Sensitivity')
                ])

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            # Save Best Model (lowest validation loss so far)
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info(f"--> New best model saved")

        # Save Final Model (last-epoch weights, regardless of val loss)
        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
242
+
243
+ # --- Data Utilities ---
244
+
245
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Build train/val loaders from every fold except `args.test_fold`.

    A single DataModule instance is reused: `paths` is reassigned and
    setup() re-run to materialize each split.
    """
    # 1. Determine Folds
    all_folds = list(range(1, args.num_folds + 1))
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # 2. Get Paths
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)

    # 3. Apply Tag (Dynamic Tagging)
    paths = get_simulated_paths(paths, tag=args.sim_tag)

    # 4. Train/Val Split (80/20)
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # 5. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=paths,  # Initial placeholder -- overwritten before each setup()
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # 6. Create Loaders
    # NOTE(review): setup("fit") also builds a val_dataset from train_paths,
    # but it is overwritten by the second setup() call below, so the final
    # loaders do use the intended splits.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
281
+
282
+ # --- Main Execution ---
283
+
284
def parse_args():
    """Command-line interface for simulated (I1/TOA/BOA) fine-tuning."""
    parser = argparse.ArgumentParser(description="Methane I1 (Simulated) Training")
    add = parser.add_argument

    # Paths
    add('--root_dir', type=str, required=True, help='Root directory for I1/TOA/BOA data')
    add('--excel_file', type=str, required=True, help='Path to Summary Excel')
    add('--save_dir', type=str, default='./checkpoints_i1', help='Output directory')

    # Simulation Tag Config
    add('--sim_tag', type=str, default='toarefl',
        help='String identifier in filename (e.g. "toarefl" or "boarefl")')

    # Hyperparameters
    add('--epochs', type=int, default=100)
    add('--batch_size', type=int, default=2, help='Batch size (must be >1 for BatchNorm)')
    add('--lr', type=float, default=1e-5)
    add('--weight_decay', type=float, default=0.05)
    add('--num_folds', type=int, default=5)
    add('--test_fold', type=int, default=4, help='Fold ID to hold out')
    add('--seed', type=int, default=42)

    return parser.parse_args()
306
+
307
if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)  # seed before the train/val split for reproducibility

    # Build loaders first so data errors surface before model construction.
    train_loader, val_loader = get_data_loaders(args)

    trainer = TrainerI1(args)
    trainer.fit(train_loader, val_loader)
urban_inference/methane_urban_datamodule.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ from torchgeo.datamodules import NonGeoDataModule
7
+ from methane_urban_dataset import MethaneUrbanDataset
8
+
9
class MethaneUrbanDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneUrbanDataset.

    All three splits are built from the same `paths` list; callers reassign
    `self.paths` between setup() calls to change the split contents.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        """
        Args:
            data_root: Root directory containing one folder per scene.
            excel_file: Path to the summary spreadsheet (forwarded to the dataset).
            paths: Folder-name prefixes included in the split(s) built by setup().
            batch_size: Samples per batch for every loader.
            num_workers: DataLoader worker processes.
            train_transform/val_transform/test_transform: Optional
                albumentations-style callables, one per split.
        """
        super().__init__(MethaneUrbanDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        """Instantiate the dataset(s) required for `stage`."""
        if stage in ("fit", "train"):
            self.train_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        # drop_last=True keeps training batch sizes uniform (e.g. for BatchNorm).
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        # FIX: drop_last=False — dropping the trailing partial batch silently
        # excluded samples from evaluation.
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )

    def test_dataloader(self):
        # FIX: drop_last=False — this module feeds inference, where every
        # sample must produce a prediction.
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )
urban_inference/methane_urban_dataset.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rasterio
3
+ import torch
4
+ from torchgeo.datasets import NonGeoDataset
5
+ from torch.utils.data import DataLoader
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
def min_max_normalize(data, new_min=0, new_max=1):
    """Rescale `data` linearly into [new_min, new_max].

    NaN values are replaced with the finite minimum, +/-inf with the finite
    maximum/minimum. A constant (or entirely non-finite) input returns a
    uniform array of `new_min`.

    FIX: the previous version passed posinf=np.max(data) / neginf=np.min(data)
    to nan_to_num — a no-op when the data actually contains +/-inf, since the
    max/min are then infinite themselves.
    """
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Derive replacement values from the finite entries only.
    finite = data[np.isfinite(data)]
    if finite.size == 0:  # All NaN/inf: nothing meaningful to scale
        return np.full_like(data, new_min, dtype=np.float32)

    old_min, old_max = float(finite.min()), float(finite.max())
    data = np.nan_to_num(data, nan=old_min, posinf=old_max, neginf=old_min)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
22
+
23
class MethaneUrbanDataset(NonGeoDataset):
    """Label-free urban inference dataset.

    Loads multi-band cubes and emits a placeholder label of 0 — no ground
    truth exists for the urban scenes, so downstream code must treat the
    'label' field as a dummy.
    """

    def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
        """
        Args:
            root_dir: Directory whose subfolders each hold one scene.
            excel_file: Unused here; kept for interface parity with the
                other methane datasets.
            paths: Folder-name prefixes matched against subfolders of root_dir.
            transform: Optional albumentations-style callable applied HWC.
            mean/std: Per-band stats; currently unused in __getitem__.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.mean = mean if mean else [0.485] * 12  # Default mean if not provided
        self.std = std if std else [0.229] * 12  # Default std if not provided

        # Match each requested prefix to the FIRST subfolder starting with it.
        for folder_name in paths:
            subdir_path = next((os.path.join(root_dir, d) for d in os.listdir(root_dir) if d.startswith(folder_name) and os.path.isdir(os.path.join(root_dir, d))), None)
            if subdir_path is not None and os.path.isdir(subdir_path):
                # NOTE(review): label_path and scube_path point at the same
                # 'hsi' file — there is no separate label raster for urban data.
                label_path = os.path.join(subdir_path, 'hsi')
                scube_path = os.path.join(subdir_path, 'hsi')
                if os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        # Number of scenes found on disk.
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the multi-band cube and keep 12 bands (band index 10 is
        # skipped; the source is assumed to have at least 13 bands).
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()
            scube_image = scube_image[[0,1,2,3,4,5,6,7,8,9,11,12], :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # [12, H, W]

        # Resize to the 224x224 model input.
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)

        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # No ground truth: every sample is labeled 0. Unlike the training
        # datasets, this is a class index (shape [1]), not a one-hot vector.
        contains_methane = torch.zeros(1).long()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)  # Convert back to [C, H, W]

        # NOTE(review): split('/')[2] assumes a fixed relative path depth and
        # POSIX separators — confirm against the actual deployment layout.
        return {'S2L2A': scube_tensor, 'label': contains_methane, 'sample': scube_path.split('/')[2]}
urban_inference/urban_inference.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import csv
4
+ import random
5
+ import warnings
6
+ import time
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Dict, List, Tuple, Any, Optional
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import torch.nn as nn
15
+ import albumentations as A
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+ from rasterio.errors import NotGeoreferencedWarning
19
+
20
+ # --- CRITICAL IMPORTS ---
21
+ import terramind
22
+ from terratorch.tasks import ClassificationTask
23
+
24
+ # Local Imports
25
+ from methane_urban_datamodule import MethaneUrbanDataModule
26
+
27
+ # --- Configuration & Setup ---
28
+
29
# Root logging setup: timestamped INFO-level messages for this script.
_LOG_FMT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=_LOG_FMT, datefmt=_LOG_DATEFMT)
logger = logging.getLogger(__name__)

# Silence noisy rasterio/GDAL environment chatter and known benign warnings.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
39
+
40
def set_seed(seed: int = 42):
    """Seed the python, numpy and torch RNGs so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA has its own per-device generators; seed them all when present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
46
+
47
def get_inference_transforms() -> Optional["A.Compose"]:
    """Return the albumentations pipeline used at inference time.

    No transforms are currently applied (the dataset already resizes and
    NaN-cleans the tensors), so this intentionally returns None.  The
    original annotation claimed a bare ``A.Compose`` return, which was
    wrong for a function that always returns None.
    """
    return None
49
+
50
+ # --- Inference Class ---
51
+
52
class UrbanInference:
    """Binary methane classification over urban S2L2A patches.

    Builds a TerraMind-backed ClassificationTask model, loads fine-tuned
    weights from a checkpoint, and writes per-sample predictions to CSV.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initializing Inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Construct the TerraMind encoder/decoder classification model.

        The architecture must mirror the one used at training time or the
        checkpoint weights will not map onto the modules.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,  # weights come from the checkpoint instead
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,  # binary: methane / no methane
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from a Lightning or plain state-dict checkpoint.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` does not exist.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)

        # Lightning checkpoints nest the weights under 'state_dict';
        # plain torch.save dumps are used as-is.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # NOTE(review): strict=False silently ignores missing/unexpected keys,
        # so a mismatched checkpoint would load without any warning here.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates binary predictions and matches them with provided sample_names (folder names).
        """
        sample_results = {}

        logger.info(f"Starting inference on {len(sample_names)} samples...")

        # Iterator for sample names to match sequential predictions.
        # Assumes the dataloader yields samples in exactly the same order as
        # sample_names (no shuffling, no dropped items) -- TODO confirm upstream.
        name_iter = iter(sample_names)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                # NOTE(review): assumes the model returns a terratorch
                # ModelOutput exposing logits via `.output` -- confirm.
                probabilities = torch.softmax(outputs.output, dim=1)

                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()

                # Assign Directory Names to Predictions
                for pred in batch_preds:
                    try:
                        dir_name = next(name_iter)
                        sample_results[dir_name] = int(pred)
                    except StopIteration:
                        logger.error("More predictions generated than sample names provided! Check dataloader sync.")
                        break

        # Check if we missed any samples
        if len(sample_results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(sample_results)}.")

        # Save CSV
        self._save_results(sample_results)

    def _save_results(self, results: Dict[str, int]):
        """Write {sample_directory: prediction} rows to inference_predictions.csv."""
        csv_path = self.output_dir / "inference_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_Directory', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")
153
+
154
+ # --- Data Loading ---
155
+
156
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the test dataloader and the ordered list of sample directory names.

    Sample names either come from the optional Excel file ('Filename' column,
    truncated at the first underscore) or from the sorted subdirectories of
    ``args.root_dir``.  The ordering of the returned names must match the
    order in which the dataloader yields samples.

    Raises:
        FileNotFoundError: if ``args.root_dir`` does not exist.
    """
    root_path = Path(args.root_dir)
    if not root_path.exists():
        raise FileNotFoundError(f"Data directory {args.root_dir} not found.")

    paths = None
    if args.excel_file:
        try:
            df = pd.read_excel(args.excel_file)
            # Ensure we get the directory name prefix if using 'Filename' column.
            # NOTE(review): assumes filenames look like '<dirname>_<suffix>' --
            # verify against the actual Excel contents.
            paths = df['Filename'].apply(lambda x: str(x).split('_')[0]).tolist()
            logger.info(f"Filtered {len(paths)} samples from Excel.")
        except Exception as e:
            logger.error(f"Error reading Excel: {e}")
            raise

    if paths is None:
        # Fallback to all subdirectories
        # SORTING is crucial here to match the sequential dataloader
        paths = sorted([d.name for d in root_path.iterdir() if d.is_dir()])
        logger.info(f"Found {len(paths)} samples in directory (Sorted).")

    # Initialize DataModule.  excel_file is deliberately None here: filtering
    # was already done above and the selection is passed through `paths`.
    datamodule = MethaneUrbanDataModule(
        data_root=args.root_dir,
        excel_file=None,
        batch_size=args.batch_size,
        paths=paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
        test_transform=get_inference_transforms()
    )

    # Setup for test stage.
    # NOTE(review): `paths` was already given to the constructor; this direct
    # re-assignment looks redundant -- confirm the datamodule actually needs it.
    datamodule.paths = paths
    datamodule.setup(stage="test")

    # Get loader (prefer test_dataloader)
    loader = datamodule.test_dataloader() if hasattr(datamodule, 'test_dataloader') else datamodule.train_dataloader()

    return loader, paths
197
+
198
+ # --- Main Execution ---
199
+
200
def parse_args():
    """Parse command-line options for urban methane inference."""
    parser = argparse.ArgumentParser(description="Methane Urban Inference (Directory Names)")

    # Mandatory inputs: where the data lives and which weights to use.
    required_flags = [
        ('--root_dir', 'Root directory containing sample folders'),
        ('--checkpoint', 'Path to model checkpoint (.pth)'),
    ]
    for flag, help_text in required_flags:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional behaviour tweaks.
    parser.add_argument('--excel_file', type=str, help='Optional Excel file to filter specific samples')
    parser.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    parser.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()
211
+
212
if __name__ == "__main__":
    cli_args = parse_args()
    set_seed(cli_args.seed)

    # 1. Prepare data first so directory names are captured before the
    #    (potentially slow) model construction.
    loader, names = get_dataloader_and_names(cli_args)

    # 2. Build the model, load weights and run prediction.
    inference_engine = UrbanInference(cli_args)
    inference_engine.run_inference(loader, names)