Upload folder using huggingface_hub
- .gitattributes +1 -0
- Methane_benchmark_patches_summary_v3.xlsx +3 -0
- README.md +27 -0
- classification/config/methane_classification_datamodule.py +116 -0
- classification/config/methane_classification_dataset.py +79 -0
- classification/config/train.yaml +85 -0
- classification/script/methane_classification_datamodule.py +71 -0
- classification/script/methane_classification_dataset.py +82 -0
- classification/script/train_classification_fine_tuning.py +329 -0
- classification_with_text/calculate_embeddings.py +51 -0
- classification_with_text/combined_caption_embeddings.csv +0 -0
- classification_with_text/script/methan_text_dataset.py +81 -0
- classification_with_text/script/methane_text_datamodule.py +72 -0
- classification_with_text/script/train_text.py +448 -0
- intuition1_classification_finetuning/config/methane_simulated_datamodule.py +116 -0
- intuition1_classification_finetuning/config/methane_simulated_dataset.py +66 -0
- intuition1_classification_finetuning/config/train.yaml +66 -0
- intuition1_classification_finetuning/script/methane_simulated_datamodule.py +72 -0
- intuition1_classification_finetuning/script/methane_simulated_dataset.py +81 -0
- intuition1_classification_finetuning/script/train_simulated_I1.py +309 -0
- sentinel2_classification_finetuning/config/methane_simulated_datamodule.py +119 -0
- sentinel2_classification_finetuning/config/methane_simulated_dataset.py +70 -0
- sentinel2_classification_finetuning/config/train.yaml +67 -0
- sentinel2_classification_finetuning/script/inference_s2_simulated.py +241 -0
- sentinel2_classification_finetuning/script/methane_simulated_datamodule.py +72 -0
- sentinel2_classification_finetuning/script/methane_simulated_dataset.py +81 -0
- sentinel2_classification_finetuning/script/train_simulated_s2.py +314 -0
- urban_inference/methane_urban_datamodule.py +71 -0
- urban_inference/methane_urban_dataset.py +82 -0
- urban_inference/urban_inference.py +221 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Methane_benchmark_patches_summary_v3.xlsx filter=lfs diff=lfs merge=lfs -text
|
Methane_benchmark_patches_summary_v3.xlsx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:522e4342850c50e22decd218f5bd546493931f57cd10642cf6d42d2224ad4890
size 349508
README.md
ADDED
@@ -0,0 +1,27 @@
# FAST-EO Use Case 2 - Methane Detection

This directory contains all data and code necessary to recreate the experiments conducted for fine-tuning Terramind-Base to detect methane in satellite images. It includes five distinct experiments along with their corresponding datasets. The attached Methane_benchmark_patches_summary_v3.xlsx file describes every patch extracted from the Methane Benchmark Dataset (MBD) and defines the fold splits that guarantee non-overlapping data. The runner scripts use this Excel file to partition the data, typically reserving one fold for testing.

Each script includes usage instructions, accessible via the `--help` (or `-h`) flag.

## Important: Ensure the Terramind package is installed before running any experiments.

## Experiment 1: Fine-tuning on the Methane Benchmark Dataset

The first experiment fine-tunes the model on the Methane Benchmark Dataset. The dataset is included in the `MBD_nan_S2_zscore` directory and has already been normalized. The training code is located in the `classification` directory, along with the necessary `dataset` and `dataloader` classes.

## Experiment 2: Fine-tuning on MBD with text captions

This experiment uses a modified version of the Terramind-Base model, which concatenates the textual embeddings of each image's caption with the visual embeddings of the base model. The text embeddings are computed using the `all-MiniLM-L6-v2` model. All the code, including the embedding calculation, and the data are available in the `classification_with_text` directory. The original captions are located in `classification_with_text/MBD_text`, and the embeddings are stored in the `combined_caption_embeddings.csv` file.
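The fusion step here is a plain feature-vector concatenation, which can be sketched shape-wise. The 384-dim size is the actual output size of `all-MiniLM-L6-v2`; the 768-dim visual size is an assumption for illustration only:

```python
import numpy as np

batch = 4
visual = np.random.rand(batch, 768).astype(np.float32)  # visual embeddings (768 is an assumed size)
text = np.random.rand(batch, 384).astype(np.float32)    # all-MiniLM-L6-v2 sentence embeddings are 384-dim
fused = np.concatenate([visual, text], axis=1)          # one fused feature vector per image
print(fused.shape)  # (4, 1152)
```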

## Experiment 3: Fine-tuning and inference on Sentinel-2 with simulated atmospheric conditions

This experiment checks how Terramind-Base behaves on Sentinel-2 data with simulated atmospheric conditions. The simulated data comes in both Top-of-Atmosphere and Bottom-of-Atmosphere variants. The model can either be trained on this data or only evaluated on it, to test how well it generalizes when trained on different data.

## Experiment 4: Fine-tuning and inference on Intuition-1 with simulated atmospheric conditions

This experiment checks how Terramind-Base behaves on Intuition-1 data with simulated atmospheric conditions. The simulated data comes in both Top-of-Atmosphere and Bottom-of-Atmosphere variants. The model can either be trained on this data or only evaluated on it, to test how well it generalizes when trained on different data.

## Experiment 5: Testing the detector on an urban dataset without methane

The urban dataset was prepared to check whether the models really learned to detect methane from multispectral data or merely look for urban signatures in the images. None of the images in this dataset contain methane; the goal is to run the models and count how many false positives are returned. Python scripts for loading and running the models are included.
classification/config/methane_classification_datamodule.py
ADDED
@@ -0,0 +1,116 @@
import pandas as pd
import albumentations as A
from typing import Optional, List
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_classification_dataset import MethaneClassificationDataset

class MethaneClassificationDataModule(NonGeoDataModule):
    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs
    ):
        # We pass the dataset class just to satisfy the parent class,
        # but we instantiate specific datasets in setup()
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        # State variables for paths
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Internal definition of training transforms"""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def setup(self, stage: str = None):
        # 1. Read the summary file (CSV or Excel, dispatched on extension)
        try:
            df = pd.read_csv(self.excel_file) if self.excel_file.endswith('.csv') else pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}")

        # 2. Filter valid paths (checking if Fold column exists or just using all data)
        # Assuming we just use all data in the file and split it 80/20 here.
        # If you need specific Fold filtering, add that logic here.
        all_paths = df['Filename'].tolist()

        # 3. Perform the split
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 4. Instantiate datasets
        if stage in ("fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,  # No transforms for validation
            )

        if stage in ("test", "predict"):
            # For testing, you might want to use a specific hold-out set.
            # For now, reusing val_paths; add logic here to load a test fold.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True
        )
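The 80/20 partition in `setup()` behaves like this standalone sketch (dummy folder names; a fixed `random_state` makes the split reproducible across runs):

```python
from sklearn.model_selection import train_test_split

all_paths = [f"patch_{i:03d}" for i in range(10)]  # stand-in for the 'Filename' column
train_paths, val_paths = train_test_split(all_paths, test_size=0.2, random_state=42)
print(len(train_paths), len(val_paths))  # 8 2
```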
classification/config/methane_classification_dataset.py
ADDED
@@ -0,0 +1,79 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd

def min_max_normalize(data, new_min=0, new_max=1):
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Handle NaN, Inf values
    data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min

class MethaneClassificationDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.mean = mean if mean else [0.485] * 12  # Default mean if not provided
        self.std = std if std else [0.229] * 12  # Default std if not provided

        # Collect paths for labelbinary.tif and sCube.tif in selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'labelbinary.tif')
                scube_path = os.path.join(subdir_path, 'sCube.tif')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)  # Shape: [512, 512]

        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]

        scube_image = scube_image[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12], :, :]  # Drop band index 10 -> Shape: [12, 512, 512]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [512, 512]

        # Resize to [12, 224, 224] and [1, 224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Convert labels to binary
        contains_methane = (label_tensor > 0).any().long()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = torch.from_numpy(transformed['image'].transpose(2, 0, 1))  # Convert back to a [C, H, W] tensor

        return {'image': scube_tensor, 'label': contains_methane, 'gt': label_image, 'sample': scube_path.split('/')[3]}
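The fancy-index band drop in `__getitem__` keeps 12 of the 13 bands by listing every index except 10; it can be checked in isolation on a tiny dummy cube:

```python
import numpy as np

cube = np.arange(13 * 4 * 4, dtype=np.float32).reshape(13, 4, 4)  # tiny stand-in for the 13-band sCube
kept = cube[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12], :, :]         # drops band index 10
print(kept.shape)  # (12, 4, 4)
```

After the drop, position 10 of the result holds what was band 11 of the input.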
classification/config/train.yaml
ADDED
@@ -0,0 +1,85 @@
# Global Seed for reproducibility (matches script's set_seed)
seed_everything: 42

# ------------------------------------------------------------------
# Trainer Configuration
# ------------------------------------------------------------------
trainer:
  accelerator: auto  # Handles "cuda" if available, else "cpu"
  strategy: auto
  devices: 1
  max_epochs: 100  # Matches args.epochs
  default_root_dir: ./checkpoints

  # Callbacks to replicate the script's checkpointing and logging
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_top_k: 1
        filename: "best_model"
        save_last: true  # Saves 'last.ckpt' (similar to 'final_model.pth')

    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

# ------------------------------------------------------------------
# Model Configuration (TerraMind + UperNet)
# ------------------------------------------------------------------
model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_factory: EncoderDecoderFactory
    loss: ce
    ignore_index: -1
    lr: 1.0e-5
    # Optimizer settings matching _init_optimizer
    optimizer: AdamW
    optimizer_hparams:
      weight_decay: 0.05

    # Scheduler settings matching ReduceLROnPlateau
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      mode: min
      patience: 5

    # --------------------------------------------------------------
    # Model Architecture (Exact match to script's model_config)
    # --------------------------------------------------------------
    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      backbone_merge_method: mean

      decoder: UperNetDecoder
      decoder_scale_modules: true
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.3

      # Specific neck configuration for TerraMind
      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]

# ------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------
data:
  class_path: methane_classification_datamodule.MethaneClassificationDataModule
  init_args:
    data_root: ../../MBD_nan_S2_zscore/MBD_nan_S2_zscore
    excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
    batch_size: 8
    val_split: 0.2
    seed: 42
# Note: The procedural train_test_split logic from the script
# (handling folds/splitting) should be encapsulated inside the
# DataModule's setup() method for this config to work seamlessly.
classification/script/methane_classification_datamodule.py
ADDED
@@ -0,0 +1,71 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_classification_dataset import MethaneClassificationDataset

class MethaneClassificationDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneClassificationDataset
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        if stage in ("fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )
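All three loaders here set `drop_last=True`, which also applies to validation and test, so a trailing partial batch is silently discarded at evaluation time. A pure-Python sketch of the batch count a loader yields (mirroring standard DataLoader semantics, not the library code itself):

```python
def batch_count(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields for a map-style dataset."""
    full, rem = divmod(n_samples, batch_size)
    # With drop_last=True the remainder batch is thrown away
    return full if (drop_last or rem == 0) else full + 1

print(batch_count(10, 8, drop_last=True))   # 1 (2 samples never seen this epoch)
print(batch_count(10, 8, drop_last=False))  # 2
```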
classification/script/methane_classification_dataset.py
ADDED
@@ -0,0 +1,82 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd

def min_max_normalize(data, new_min=0, new_max=1):
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Handle NaN, Inf values
    data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min

class MethaneClassificationDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.mean = mean if mean else [0.485] * 12  # Default mean if not provided
        self.std = std if std else [0.229] * 12  # Default std if not provided

        # Collect paths for labelbinary.tif and sCube.tif in selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'labelbinary.tif')
                scube_path = os.path.join(subdir_path, 'sCube.tif')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)  # Shape: [512, 512]

        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]
            # Make sure band 10 is excluded below

        scube_image = scube_image[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12], :, :]  # Drop band index 10 -> Shape: [12, 512, 512]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [512, 512]

        # Resize to [12, 224, 224] and [1, 224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Convert labels to binary
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = torch.from_numpy(transformed['image'].transpose(2, 0, 1))  # Convert back to a [C, H, W] tensor

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'gt': label_image, 'sample': scube_path.split('/')[3]}
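`min_max_normalize` (defined identically in both dataset files) can be sanity-checked in isolation; the helper is re-stated verbatim here only so the snippet runs standalone:

```python
import numpy as np

def min_max_normalize(data, new_min=0, new_max=1):
    # Verbatim copy of the helper from the dataset files
    data = np.array(data, dtype=np.float32)
    data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))
    old_min, old_max = np.min(data), np.max(data)
    if old_max == old_min:
        return np.full_like(data, new_min, dtype=np.float32)
    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min

out = min_max_normalize([2.0, 4.0, 6.0])
print(out)  # approximately [0.0, 0.5, 1.0]
```

The `1e-10` in the denominator keeps the maximum marginally below `new_max`, which is why the check below uses a tolerance.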
classification/script/train_classification_fine_tuning.py
ADDED
@@ -0,0 +1,329 @@
import argparse
import logging
import csv
import random
import warnings
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
)
from rasterio.errors import NotGeoreferencedWarning
import terramind

# Local Imports
from methane_classification_datamodule import MethaneClassificationDataModule

# TerraTorch Imports
from terratorch.tasks import ClassificationTask


# --- Configuration & Setup ---

# Configure Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress Warnings
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def set_seed(seed: int = 42):
    """Sets the seed for reproducibility across random, numpy, and torch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_training_transforms() -> A.Compose:
    """Returns the albumentations training pipeline."""
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])

# --- Helper Classes ---

class MetricTracker:
    """Accumulates targets and predictions to calculate epoch-level metrics."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        self.total_loss += loss
        self.steps += 1
        # Store detached cpu numpy arrays to avoid VRAM leaks
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        """Calculates aggregate metrics for the accumulated data."""
        if not self.all_targets:
            return {}

        # Calculate Confusion Matrix elements
        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
        }

class MethaneTrainer:
    """
    Handles the training lifecycle: Model setup, Training loop, Validation, and Checkpointing.
    """
    def __init__(self, args: argparse.Namespace):
        self.args = args
|
| 107 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 108 |
+
self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
|
| 109 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 110 |
+
|
| 111 |
+
self.model = self._init_model()
|
| 112 |
+
self.optimizer, self.scheduler = self._init_optimizer()
|
| 113 |
+
self.criterion = self.task.criterion # Retrieved from the TerraTorch task
|
| 114 |
+
|
| 115 |
+
self.best_val_loss = float('inf')
|
| 116 |
+
|
| 117 |
+
logger.info(f"Trainer initialized on device: {self.device}")
|
| 118 |
+
|
| 119 |
+
def _init_model(self) -> nn.Module:
|
| 120 |
+
"""Initializes the TerraTorch Classification Task and Model."""
|
| 121 |
+
model_config = dict(
|
| 122 |
+
backbone="terramind_v1_base",
|
| 123 |
+
backbone_pretrained=True,
|
| 124 |
+
backbone_modalities=["S2L2A"],
|
| 125 |
+
backbone_merge_method="mean",
|
| 126 |
+
decoder="UperNetDecoder",
|
| 127 |
+
decoder_scale_modules=True,
|
| 128 |
+
decoder_channels=256,
|
| 129 |
+
num_classes=2,
|
| 130 |
+
head_dropout=0.3,
|
| 131 |
+
necks=[
|
| 132 |
+
{"name": "ReshapeTokensToImage", "remove_cls_token": False},
|
| 133 |
+
{"name": "SelectIndices", "indices": [2, 5, 8, 11]},
|
| 134 |
+
],
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.task = ClassificationTask(
|
| 138 |
+
model_args=model_config,
|
| 139 |
+
model_factory="EncoderDecoderFactory",
|
| 140 |
+
loss="ce",
|
| 141 |
+
lr=self.args.lr,
|
| 142 |
+
ignore_index=-1,
|
| 143 |
+
optimizer="AdamW",
|
| 144 |
+
optimizer_hparams={"weight_decay": self.args.weight_decay},
|
| 145 |
+
)
|
| 146 |
+
self.task.configure_models()
|
| 147 |
+
self.task.configure_losses()
|
| 148 |
+
return self.task.model.to(self.device)
|
| 149 |
+
|
| 150 |
+
def _init_optimizer(self):
|
| 151 |
+
optimizer = optim.AdamW(
|
| 152 |
+
self.model.parameters(),
|
| 153 |
+
lr=self.args.lr,
|
| 154 |
+
weight_decay=self.args.weight_decay
|
| 155 |
+
)
|
| 156 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 157 |
+
optimizer, mode='min', patience=5, verbose=True
|
| 158 |
+
)
|
| 159 |
+
return optimizer, scheduler
|
| 160 |
+
|
| 161 |
+
def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
|
| 162 |
+
"""Runs a single epoch for either training or validation."""
|
| 163 |
+
is_train = stage == "train"
|
| 164 |
+
self.model.train() if is_train else self.model.eval()
|
| 165 |
+
|
| 166 |
+
tracker = MetricTracker()
|
| 167 |
+
|
| 168 |
+
# Context manager: enable grad only if training
|
| 169 |
+
with torch.set_grad_enabled(is_train):
|
| 170 |
+
pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
|
| 171 |
+
|
| 172 |
+
for batch in pbar:
|
| 173 |
+
inputs = batch['S2L2A'].to(self.device)
|
| 174 |
+
targets = batch['label'].to(self.device)
|
| 175 |
+
|
| 176 |
+
# Forward Pass
|
| 177 |
+
outputs = self.model(x={"S2L2A": inputs})
|
| 178 |
+
probabilities = torch.softmax(outputs.output, dim=1)
|
| 179 |
+
loss = self.criterion(probabilities, targets)
|
| 180 |
+
|
| 181 |
+
if is_train:
|
| 182 |
+
self.optimizer.zero_grad()
|
| 183 |
+
loss.backward()
|
| 184 |
+
self.optimizer.step()
|
| 185 |
+
|
| 186 |
+
# Update metrics
|
| 187 |
+
tracker.update(loss.item(), targets, probabilities)
|
| 188 |
+
|
| 189 |
+
# Update progress bar description with live loss
|
| 190 |
+
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
| 191 |
+
|
| 192 |
+
return tracker.compute()
|
| 193 |
+
|
| 194 |
+
def save_checkpoint(self, filename: str):
|
| 195 |
+
path = self.save_dir / filename
|
| 196 |
+
torch.save(self.model.state_dict(), path)
|
| 197 |
+
logger.info(f"Saved model to {path}")
|
| 198 |
+
|
| 199 |
+
def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
|
| 200 |
+
"""Appends metrics to the CSV log file."""
|
| 201 |
+
csv_path = self.save_dir / 'train_val_metrics.csv'
|
| 202 |
+
file_exists = csv_path.exists()
|
| 203 |
+
|
| 204 |
+
# Define headers based on metric keys
|
| 205 |
+
headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]
|
| 206 |
+
|
| 207 |
+
with open(csv_path, mode='a', newline='') as f:
|
| 208 |
+
writer = csv.writer(f)
|
| 209 |
+
if not file_exists:
|
| 210 |
+
writer.writerow(headers)
|
| 211 |
+
|
| 212 |
+
row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())
|
| 213 |
+
writer.writerow(row)
|
| 214 |
+
|
| 215 |
+
def fit(self, train_loader: DataLoader, val_loader: DataLoader):
|
| 216 |
+
"""Main training entry point."""
|
| 217 |
+
logger.info(f"Starting training for {self.args.epochs} epochs...")
|
| 218 |
+
start_time = time.time()
|
| 219 |
+
|
| 220 |
+
for epoch in range(1, self.args.epochs + 1):
|
| 221 |
+
logger.info(f"Epoch {epoch}/{self.args.epochs}")
|
| 222 |
+
|
| 223 |
+
# Run Training & Validation
|
| 224 |
+
train_metrics = self.run_epoch(train_loader, stage="train")
|
| 225 |
+
val_metrics = self.run_epoch(val_loader, stage="validate")
|
| 226 |
+
|
| 227 |
+
# Scheduler Step
|
| 228 |
+
self.scheduler.step(val_metrics['Loss'])
|
| 229 |
+
|
| 230 |
+
# Logging
|
| 231 |
+
self.log_to_csv(epoch, train_metrics, val_metrics)
|
| 232 |
+
logger.info(
|
| 233 |
+
f"Train Loss: {train_metrics['Loss']:.4f} | "
|
| 234 |
+
f"Val Loss: {val_metrics['Loss']:.4f} | "
|
| 235 |
+
f"Val F1: {val_metrics['F1']:.4f}"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Save Best Model
|
| 239 |
+
if val_metrics['Loss'] < self.best_val_loss:
|
| 240 |
+
self.best_val_loss = val_metrics['Loss']
|
| 241 |
+
self.save_checkpoint("best_model.pth")
|
| 242 |
+
logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")
|
| 243 |
+
|
| 244 |
+
# End of training
|
| 245 |
+
self.save_checkpoint("final_model.pth")
|
| 246 |
+
logger.info(f"Training finished in {time.time() - start_time:.2f}s")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# --- Data Utilities ---
|
| 250 |
+
|
| 251 |
+
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
|
| 252 |
+
"""Prepares DataModule and returns Train/Val loaders."""
|
| 253 |
+
|
| 254 |
+
# Read Excel and Filter Folds
|
| 255 |
+
try:
|
| 256 |
+
df = pd.read_csv(args.excel_file) if args.excel_file.endswith('.csv') else pd.read_excel(args.excel_file)
|
| 257 |
+
except Exception as e:
|
| 258 |
+
logger.error(f"Failed to load summary file: {e}")
|
| 259 |
+
raise
|
| 260 |
+
|
| 261 |
+
# Determine training pool (all folds except test_fold)
|
| 262 |
+
all_folds = range(1, args.num_folds + 1)
|
| 263 |
+
train_pool_folds = [f for f in all_folds if f != args.test_fold]
|
| 264 |
+
|
| 265 |
+
# Filter filenames
|
| 266 |
+
df_filtered = df[df['Fold'].isin(train_pool_folds)]
|
| 267 |
+
if df_filtered.empty:
|
| 268 |
+
raise ValueError(f"No data found for folds {train_pool_folds}. Check 'Fold' column in Excel.")
|
| 269 |
+
|
| 270 |
+
paths = df_filtered['Filename'].tolist()
|
| 271 |
+
|
| 272 |
+
# 80/20 Split
|
| 273 |
+
train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
|
| 274 |
+
|
| 275 |
+
logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")
|
| 276 |
+
|
| 277 |
+
# Initialize DataModule
|
| 278 |
+
datamodule = MethaneClassificationDataModule(
|
| 279 |
+
data_root=args.root_dir,
|
| 280 |
+
excel_file=args.excel_file,
|
| 281 |
+
batch_size=args.batch_size,
|
| 282 |
+
paths=train_paths,
|
| 283 |
+
train_transform=get_training_transforms(),
|
| 284 |
+
val_transform=None,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Create Loaders
|
| 288 |
+
datamodule.paths = train_paths
|
| 289 |
+
datamodule.setup(stage="fit")
|
| 290 |
+
train_loader = datamodule.train_dataloader()
|
| 291 |
+
|
| 292 |
+
datamodule.paths = val_paths
|
| 293 |
+
datamodule.setup(stage="validate")
|
| 294 |
+
val_loader = datamodule.val_dataloader()
|
| 295 |
+
|
| 296 |
+
return train_loader, val_loader
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# --- Main Execution ---
|
| 300 |
+
|
| 301 |
+
def parse_args():
|
| 302 |
+
parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")
|
| 303 |
+
|
| 304 |
+
# Paths
|
| 305 |
+
parser.add_argument('--root_dir', type=str, required=True, help='Root directory for satellite images')
|
| 306 |
+
parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel/CSV file')
|
| 307 |
+
parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory to save outputs')
|
| 308 |
+
|
| 309 |
+
# Training Config
|
| 310 |
+
parser.add_argument('--epochs', type=int, default=100)
|
| 311 |
+
parser.add_argument('--batch_size', type=int, default=8)
|
| 312 |
+
parser.add_argument('--lr', type=float, default=1e-5)
|
| 313 |
+
parser.add_argument('--weight_decay', type=float, default=0.05)
|
| 314 |
+
parser.add_argument('--num_folds', type=int, default=5)
|
| 315 |
+
parser.add_argument('--test_fold', type=int, default=2, help='Fold ID to hold out for testing')
|
| 316 |
+
parser.add_argument('--seed', type=int, default=42)
|
| 317 |
+
|
| 318 |
+
return parser.parse_args()
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
|
| 321 |
+
args = parse_args()
|
| 322 |
+
set_seed(args.seed)
|
| 323 |
+
|
| 324 |
+
# Prepare Data
|
| 325 |
+
train_loader, val_loader = get_data_loaders(args)
|
| 326 |
+
|
| 327 |
+
# Initialize Trainer and Start
|
| 328 |
+
trainer = MethaneTrainer(args)
|
| 329 |
+
trainer.fit(train_loader, val_loader)
|
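The Specificity entry in `MetricTracker.compute` uses TN / (TN + FP) with a guard against an all-positive batch. A minimal pure-Python sketch (hypothetical helpers, no sklearn dependency) sanity-checks that formula:

```python
def binary_counts(targets, predictions):
    """Count confusion-matrix cells for binary labels (0 = no plume, 1 = plume)."""
    tn = sum(1 for t, p in zip(targets, predictions) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(targets, predictions) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(targets, predictions) if t == 1 and p == 0)
    tp = sum(1 for t, p in zip(targets, predictions) if t == 1 and p == 1)
    return tn, fp, fn, tp

def specificity(targets, predictions):
    """TN / (TN + FP), returning 0.0 when there are no true negatives to measure."""
    tn, fp, _, _ = binary_counts(targets, predictions)
    return tn / (tn + fp) if (tn + fp) != 0 else 0.0

targets     = [0, 0, 0, 1, 1, 0]
predictions = [0, 1, 0, 1, 0, 0]
print(specificity(targets, predictions))  # 3 TN out of 4 negatives -> 0.75
```

This mirrors only the Specificity arithmetic; the script itself delegates the remaining metrics to sklearn.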
classification_with_text/calculate_embeddings.py
ADDED
@@ -0,0 +1,51 @@
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
from pathlib import Path
from tqdm import tqdm

def extract_caption(text_block):
    """Return the text after the last 'CAPTION:' marker (case-insensitive)."""
    for line in text_block.splitlines():
        upper = line.upper()
        if "CAPTION:" in upper:
            # Locate the marker on the upper-cased copy so mixed-case
            # markers (e.g. 'Caption:') are handled as well.
            start = upper.rindex("CAPTION:") + len("CAPTION:")
            return line[start:].strip()
    return ""

def load_captions_from_files(json_files):
    all_paths = []
    all_captions = []

    for json_path in tqdm(json_files, desc="Reading files"):
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for img_path, outer_list in data.items():
            if not outer_list or not outer_list[0]:
                continue
            text_block = outer_list[0][0]
            caption = extract_caption(text_block)
            if caption:
                all_paths.append(img_path)
                all_captions.append(caption)

    return all_paths, all_captions

def compute_and_save_embeddings(json_files, output_csv):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    image_paths, captions = load_captions_from_files(json_files)

    if not captions:
        print("No valid captions found across input files.")
        return

    embeddings = model.encode(captions, show_progress_bar=True)
    df = pd.DataFrame(embeddings)
    df.insert(0, "image_path", image_paths)
    df.to_csv(output_csv, index=False)
    print(f"Saved {len(df)} embeddings from {len(json_files)} files to {output_csv}")

# Example usage
if __name__ == "__main__":
    import glob
    # Collect multiple JSON files
    files = glob.glob("./MBD_text/*.json")  # or provide manually
    compute_and_save_embeddings(files, "combined_caption_embeddings.csv")
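For reference, the caption-parsing step behaves like this on a typical text block (the sketch re-declares a simplified helper so it runs standalone; the sample block is hypothetical):

```python
def extract_caption(text_block):
    # Return the text after the last "CAPTION:" marker on the first matching line
    for line in text_block.splitlines():
        if "CAPTION:" in line.upper():
            return line.split("CAPTION:")[-1].strip()
    return ""

block = "TITLE: Sentinel-2 patch\nCAPTION: Methane plume over an urban area\nNOTES: none"
print(extract_caption(block))  # Methane plume over an urban area
```

Blocks without a marker yield an empty string, which `load_captions_from_files` then skips.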
classification_with_text/combined_caption_embeddings.csv
ADDED
The diff for this file is too large to render. See raw diff.
classification_with_text/script/methan_text_dataset.py
ADDED
@@ -0,0 +1,81 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json

class MethaneTextDataset(NonGeoDataset):
    def __init__(self, root_dir, paths, captions, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.captions_dict = captions

        # Collect paths for labelbinary.tif and sCube.tif in the selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'labelbinary.tif')
                scube_path = os.path.join(subdir_path, 'sCube.tif')
                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))
                else:
                    print(f"Warning: Missing files in {subdir_path}. Expected labelbinary.tif and sCube.tif.")

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]
        folder_name = label_path.split("/")[-2]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)  # Shape: [512, 512]

        # Load the sCube image (multi-band), drop the first band
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]
            scube_image = scube_image[1:, :, :]  # Drop first band -> Shape: [12, 512, 512]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [512, 512]

        # Resize to [12, 224, 224] and [224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Collapse the pixel mask to a binary scene-level label
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)  # Convert back to [C, H, W]

        # Look up the caption for this folder, falling back to a placeholder
        caption = self.captions_dict.get(folder_name, 'No caption')

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'caption': caption}
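The scene-level labeling step (any positive pixel in the mask makes the patch class 1, then one-hot encode) can be sketched without the rasterio/torch stack, using plain Python lists as a stand-in for the mask:

```python
def scene_label(mask_rows):
    """Collapse a binary pixel mask to a one-hot scene label:
    [1.0, 0.0] = no methane, [0.0, 1.0] = methane present."""
    contains_methane = any(v > 0 for row in mask_rows for v in row)
    return [0.0, 1.0] if contains_methane else [1.0, 0.0]

print(scene_label([[0, 0], [0, 1]]))  # [0.0, 1.0]
print(scene_label([[0, 0], [0, 0]]))  # [1.0, 0.0]
```

In the dataset itself this is the `(label_tensor > 0).any()` followed by `F.one_hot(..., num_classes=2)`.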
classification_with_text/script/methane_text_datamodule.py
ADDED
@@ -0,0 +1,72 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule

from methan_text_dataset import MethaneTextDataset

class MethaneTextDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneTextDataset.
    """

    def __init__(
        self,
        data_root: str,
        paths: list,
        captions: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneTextDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.paths = paths
        self.captions = captions
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        if stage in ("fit", "train"):
            self.train_dataset = MethaneTextDataset(
                root_dir=self.data_root,
                paths=self.paths,
                captions=self.captions,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneTextDataset(
                root_dir=self.data_root,
                paths=self.paths,
                captions=self.captions,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneTextDataset(
                root_dir=self.data_root,
                paths=self.paths,
                captions=self.captions,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )
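The `setup` method follows the Lightning convention that the `fit` stage prepares both the train and validation splits. A minimal stand-in (hypothetical helper, no torchgeo dependency) shows which dataset attributes each stage builds:

```python
def datasets_for_stage(stage):
    """Mirror the setup() branching: return which dataset attributes get built."""
    built = []
    if stage in ("fit", "train"):
        built.append("train_dataset")
    if stage in ("fit", "validate", "val"):
        built.append("val_dataset")
    if stage in ("test", "predict"):
        built.append("test_dataset")
    return built

print(datasets_for_stage("fit"))   # ['train_dataset', 'val_dataset']
print(datasets_for_stage("test"))  # ['test_dataset']
```

The training scripts exploit this by reassigning `datamodule.paths` between `setup("fit")` and `setup("validate")` calls to point each split at its own folder list.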
classification_with_text/script/train_text.py
ADDED
@@ -0,0 +1,448 @@
import argparse
import logging
import csv
import random
import warnings
import time
import json
from pathlib import Path
from functools import partial
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
)
from rasterio.errors import NotGeoreferencedWarning
from sentence_transformers import SentenceTransformer

# --- CRITICAL IMPORTS ---
import terramind
from terratorch.tasks import ClassificationTask
from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY
from terramind.models.terramind_register import build_terrammind_vit

# Local Imports
from methane_text_datamodule import MethaneTextDataModule

# --- Configuration & Setup ---

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# --- Global Constants ---
PRETRAINED_BANDS = {
    'untok_sen2l2a@224': [
        "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
        "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "SWIR_1", "SWIR_2",
    ]
}

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_training_transforms() -> A.Compose:
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])

# --- Custom Model Components (From Notebook) ---

# Initialize the sentence transformer globally to avoid reloading
try:
    EMBB_MODEL = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    # Move to GPU if available for faster encoding during training;
    # usage in forward() implies dynamic encoding.
    if torch.cuda.is_available():
        EMBB_MODEL = EMBB_MODEL.to("cuda")
except Exception as e:
    logger.warning(f"Could not load SentenceTransformer: {e}")
    EMBB_MODEL = None

class TerraMindWithText(nn.Module):
    def __init__(self, terramind_kwargs: dict):
        super().__init__()
        self.terramind = build_terrammind_vit(
            variant='terramind_v1_base',
            encoder_depth=12,
            dim=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=False,
            proj_bias=False,
            mlp_bias=False,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.SiLU,
            gated_mlp=True,
            pretrained_bands=PRETRAINED_BANDS,
            **terramind_kwargs
        )
        self.out_channels = [768] * 12
        # self.project = nn.Linear(768 + 512, 768*192)  # Referenced in the notebook but unused in forward

    def forward(self, x, captions):
        vision_features = self.terramind(x)  # list of per-layer token features

        # Encode captions.
        # Note: EMBB_MODEL.encode returns numpy or tensor; ensure it is on the correct device.
        with torch.no_grad():
            captions_embed = EMBB_MODEL.encode(captions, convert_to_tensor=True, show_progress_bar=False)

        # Ensure dimensionality matches what the decoder expects
        # (squeeze if necessary, though encode usually returns [B, D])
        if len(captions_embed.shape) == 3:
            captions_embed = captions_embed.squeeze()

        return vision_features + [captions_embed]

@TERRATORCH_BACKBONE_REGISTRY.register
def terramind_v1_base_with_text(**kwargs):
    return TerraMindWithText(terramind_kwargs=kwargs)

@TERRATORCH_DECODER_REGISTRY.register
class SimpleDecoder(nn.Module):
    includes_head = True

    def __init__(self, input_dim=768, num_classes=2, caption_dim=384):
        super().__init__()
        # Handle input_dim if passed as a list (common in TerraTorch)
        dim = input_dim[0] if isinstance(input_dim, (list, tuple)) else input_dim

        self.image_conv = nn.Sequential(
            nn.Conv2d(dim, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        self.caption_mlp = nn.Sequential(
            nn.Linear(caption_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=256, num_heads=8, dropout=0.1, batch_first=True
        )

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        self.conv_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        self.out_channels = 1

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        # The features list contains: [vision_feat_0, ..., vision_feat_11, caption_embed]
        caption_embed = features[-1]   # [B, 384]
        image_features = features[:12]

        # Average vision tokens across layers
        x = torch.stack(image_features, dim=1).mean(dim=1)  # [B, 196, 768]

        B, N, C = x.shape
        H = W = int(N ** 0.5)

        x = x.permute(0, 2, 1).view(B, C, H, W)  # [B, 768, 14, 14]
        img_features = self.image_conv(x)        # [B, 256, 14, 14]

        # Ensure the caption embedding has a batch dim
        if caption_embed.dim() == 1:
            caption_embed = caption_embed.unsqueeze(0)

        caption_features = self.caption_mlp(caption_embed)  # [B, 256]

        # Expand the caption to spatial dims
        caption_spatial = caption_features.unsqueeze(-1).unsqueeze(-1)
        caption_spatial = caption_spatial.expand(B, -1, H, W)  # [B, 256, 14, 14]

        # Fuse
        fused_features = torch.cat([img_features, caption_spatial], dim=1)  # [B, 512, 14, 14]
        fused = self.fusion_conv(fused_features)  # [B, 128, 14, 14]

        output = self.conv_head(fused)  # [B, 1, 14, 14]
        return output

# --- Helper Classes ---

class MetricTracker:
    def __init__(self):
        self.reset()

    def reset(self):
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        self.total_loss += loss
        self.steps += 1
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        if not self.all_targets:
            return {}

        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
|
| 239 |
+
"MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
class MethaneTextTrainer:
|
| 243 |
+
def __init__(self, args: argparse.Namespace):
|
| 244 |
+
self.args = args
|
| 245 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 246 |
+
self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
|
| 247 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 248 |
+
|
| 249 |
+
self.model = self._init_model()
|
| 250 |
+
self.optimizer, self.scheduler = self._init_optimizer()
|
| 251 |
+
self.criterion = self.task.criterion
|
| 252 |
+
self.best_val_loss = float('inf')
|
| 253 |
+
|
| 254 |
+
logger.info(f"Trainer initialized on device: {self.device}")
|
| 255 |
+
|
| 256 |
+
def _init_model(self) -> nn.Module:
|
| 257 |
+
model_args = dict(
|
| 258 |
+
backbone="terramind_v1_base_with_text",
|
| 259 |
+
backbone_pretrained=True,
|
| 260 |
+
backbone_modalities=["S2L2A"],
|
| 261 |
+
backbone_merge_method="mean",
|
| 262 |
+
num_classes=2,
|
| 263 |
+
head_dropout=0.3,
|
| 264 |
+
decoder="SimpleDecoder",
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
self.task = ClassificationTask(
|
| 268 |
+
model_args=model_args,
|
| 269 |
+
model_factory="EncoderDecoderFactory",
|
| 270 |
+
loss="ce",
|
| 271 |
+
lr=self.args.lr,
|
| 272 |
+
ignore_index=-1,
|
| 273 |
+
optimizer="AdamW",
|
| 274 |
+
optimizer_hparams={"weight_decay": self.args.weight_decay},
|
| 275 |
+
)
|
| 276 |
+
self.task.configure_models()
|
| 277 |
+
self.task.configure_losses()
|
| 278 |
+
return self.task.model.to(self.device)
|
| 279 |
+
|
| 280 |
+
def _init_optimizer(self):
|
| 281 |
+
optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
|
| 282 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
|
| 283 |
+
return optimizer, scheduler
|
| 284 |
+
|
| 285 |
+
def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
|
| 286 |
+
is_train = stage == "train"
|
| 287 |
+
self.model.train() if is_train else self.model.eval()
|
| 288 |
+
tracker = MetricTracker()
|
| 289 |
+
|
| 290 |
+
with torch.set_grad_enabled(is_train):
|
| 291 |
+
pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
|
| 292 |
+
for batch in pbar:
|
| 293 |
+
# Prepare Inputs
|
| 294 |
+
inputs = batch['S2L2A'].to(self.device)
|
| 295 |
+
captions = batch['caption'] # List of strings
|
| 296 |
+
targets = batch['label'].to(self.device)
|
| 297 |
+
|
| 298 |
+
# Forward Pass (Note: passing captions explicitly)
|
| 299 |
+
# The Task wrapper might expect x dict, but our custom backbone forward handles 'captions'
|
| 300 |
+
outputs = self.model(x={"S2L2A": inputs}, captions=captions)
|
| 301 |
+
probabilities = torch.softmax(outputs.output, dim=1)
|
| 302 |
+
loss = self.criterion(probabilities, targets)
|
| 303 |
+
|
| 304 |
+
if is_train:
|
| 305 |
+
self.optimizer.zero_grad()
|
| 306 |
+
loss.backward()
|
| 307 |
+
self.optimizer.step()
|
| 308 |
+
|
| 309 |
+
tracker.update(loss.item(), targets, probabilities)
|
| 310 |
+
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
| 311 |
+
|
| 312 |
+
return tracker.compute()
|
| 313 |
+
|
| 314 |
+
def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
|
| 315 |
+
csv_path = self.save_dir / 'train_val_metrics.csv'
|
| 316 |
+
headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]
|
| 317 |
+
|
| 318 |
+
with open(csv_path, mode='a', newline='') as f:
|
| 319 |
+
writer = csv.writer(f)
|
| 320 |
+
if not csv_path.exists():
|
| 321 |
+
writer.writerow(headers)
|
| 322 |
+
writer.writerow([epoch] + list(train_metrics.values()) + list(val_metrics.values()))
|
| 323 |
+
|
| 324 |
+
def fit(self, train_loader: DataLoader, val_loader: DataLoader):
|
| 325 |
+
logger.info(f"Starting training for {self.args.epochs} epochs...")
|
| 326 |
+
start_time = time.time()
|
| 327 |
+
|
| 328 |
+
for epoch in range(1, self.args.epochs + 1):
|
| 329 |
+
logger.info(f"Epoch {epoch}/{self.args.epochs}")
|
| 330 |
+
|
| 331 |
+
train_metrics = self.run_epoch(train_loader, stage="train")
|
| 332 |
+
val_metrics = self.run_epoch(val_loader, stage="validate")
|
| 333 |
+
|
| 334 |
+
self.scheduler.step(val_metrics['Loss'])
|
| 335 |
+
self.log_to_csv(epoch, train_metrics, val_metrics)
|
| 336 |
+
|
| 337 |
+
logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")
|
| 338 |
+
|
| 339 |
+
if val_metrics['Loss'] < self.best_val_loss:
|
| 340 |
+
self.best_val_loss = val_metrics['Loss']
|
| 341 |
+
torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
|
| 342 |
+
logger.info(f"--> New best model saved")
|
| 343 |
+
|
| 344 |
+
torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
|
| 345 |
+
logger.info(f"Training finished in {time.time() - start_time:.2f}s")
|
| 346 |
+
|
| 347 |
+
# --- Data Utilities ---
|
| 348 |
+
|
| 349 |
+
def read_captions(json_path: Path, captions_dict: Dict) -> Dict:
|
| 350 |
+
"""Reads captions from JSON and populates dictionary."""
|
| 351 |
+
if not json_path.exists():
|
| 352 |
+
logger.warning(f"Caption file not found: {json_path}")
|
| 353 |
+
return captions_dict
|
| 354 |
+
|
| 355 |
+
try:
|
| 356 |
+
with open(json_path, "r", encoding="utf-8") as file:
|
| 357 |
+
data = json.load(file)
|
| 358 |
+
|
| 359 |
+
for file_path_str, text_list in data.items():
|
| 360 |
+
if text_list and isinstance(text_list, list) and text_list[0]:
|
| 361 |
+
text_content = text_list[0][0]
|
| 362 |
+
caption_start = text_content.find("CAPTION:")
|
| 363 |
+
if caption_start != -1:
|
| 364 |
+
caption = text_content[caption_start + len("CAPTION:"):].strip()
|
| 365 |
+
# Extract folder name (assumes specific directory structure from notebook)
|
| 366 |
+
# "path\\to\\folder\\image.ext" -> "folder"
|
| 367 |
+
path_parts = file_path_str.replace("\\", "/").split("/")
|
| 368 |
+
if len(path_parts) >= 2:
|
| 369 |
+
last_directory = path_parts[-2]
|
| 370 |
+
captions_dict[last_directory] = caption
|
| 371 |
+
except Exception as e:
|
| 372 |
+
logger.error(f"Error reading captions {json_path}: {e}")
|
| 373 |
+
|
| 374 |
+
return captions_dict
|
| 375 |
+
|
| 376 |
+
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
|
| 377 |
+
df = pd.read_excel(excel_file)
|
| 378 |
+
df_filtered = df[df['Fold'].isin(folds)]
|
| 379 |
+
return df_filtered['Filename'].tolist()
|
| 380 |
+
|
| 381 |
+
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
|
| 382 |
+
# 1. Load Captions
|
| 383 |
+
captions_dict = {}
|
| 384 |
+
captions_dict = read_captions(Path(args.methane_captions), captions_dict)
|
| 385 |
+
captions_dict = read_captions(Path(args.no_methane_captions), captions_dict)
|
| 386 |
+
logger.info(f"Loaded {len(captions_dict)} captions.")
|
| 387 |
+
|
| 388 |
+
# 2. Get File Paths
|
| 389 |
+
all_folds = range(1, args.num_folds + 1)
|
| 390 |
+
train_pool_folds = [f for f in all_folds if f != args.test_fold]
|
| 391 |
+
paths = get_paths_for_fold(args.excel_file, train_pool_folds)
|
| 392 |
+
|
| 393 |
+
# 3. Split
|
| 394 |
+
train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
|
| 395 |
+
logger.info(f"Train: {len(train_paths)}, Val: {len(val_paths)}")
|
| 396 |
+
|
| 397 |
+
# 4. DataModule
|
| 398 |
+
datamodule = MethaneTextDataModule(
|
| 399 |
+
data_root=args.root_dir,
|
| 400 |
+
paths=paths, # Initial dummy
|
| 401 |
+
captions=captions_dict,
|
| 402 |
+
train_transform=get_training_transforms(),
|
| 403 |
+
batch_size=args.batch_size,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Train Loader
|
| 407 |
+
datamodule.paths = train_paths
|
| 408 |
+
datamodule.setup(stage="train")
|
| 409 |
+
train_loader = datamodule.train_dataloader()
|
| 410 |
+
|
| 411 |
+
# Val Loader
|
| 412 |
+
datamodule.paths = val_paths
|
| 413 |
+
datamodule.setup(stage="validate")
|
| 414 |
+
val_loader = datamodule.val_dataloader()
|
| 415 |
+
|
| 416 |
+
return train_loader, val_loader
|
| 417 |
+
|
| 418 |
+
# --- Main Execution ---
|
| 419 |
+
|
| 420 |
+
def parse_args():
|
| 421 |
+
parser = argparse.ArgumentParser(description="Methane Text-Multimodal Training")
|
| 422 |
+
|
| 423 |
+
# Data Paths
|
| 424 |
+
parser.add_argument('--root_dir', type=str, required=True, help='Root directory for images')
|
| 425 |
+
parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel')
|
| 426 |
+
parser.add_argument('--methane_captions', type=str, required=True, help='Path to Methane JSON captions')
|
| 427 |
+
parser.add_argument('--no_methane_captions', type=str, required=True, help='Path to No-Methane JSON captions')
|
| 428 |
+
parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Output directory')
|
| 429 |
+
|
| 430 |
+
# Hyperparameters
|
| 431 |
+
parser.add_argument('--epochs', type=int, default=100)
|
| 432 |
+
parser.add_argument('--batch_size', type=int, default=4)
|
| 433 |
+
parser.add_argument('--lr', type=float, default=5e-5)
|
| 434 |
+
parser.add_argument('--weight_decay', type=float, default=0.05)
|
| 435 |
+
parser.add_argument('--num_folds', type=int, default=5)
|
| 436 |
+
parser.add_argument('--test_fold', type=int, default=2)
|
| 437 |
+
parser.add_argument('--seed', type=int, default=42)
|
| 438 |
+
|
| 439 |
+
return parser.parse_args()
|
| 440 |
+
|
| 441 |
+
if __name__ == "__main__":
|
| 442 |
+
args = parse_args()
|
| 443 |
+
set_seed(args.seed)
|
| 444 |
+
|
| 445 |
+
train_loader, val_loader = get_data_loaders(args)
|
| 446 |
+
|
| 447 |
+
trainer = MethaneTextTrainer(args)
|
| 448 |
+
trainer.fit(train_loader, val_loader)
|
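The `MetricTracker` above derives specificity directly from confusion-matrix counts while delegating the other metrics to scikit-learn. A minimal pure-Python sketch of the same count-based formulas, on made-up label lists, to show what each entry means:

```python
# Pure-Python sketch of the confusion-matrix metrics MetricTracker computes.
# The target/prediction lists below are illustrative, not from the real dataset.

def binary_metrics(targets, predictions):
    tp = sum(1 for t, p in zip(targets, predictions) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(targets, predictions) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(targets, predictions) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(targets, predictions) if t == 1 and p == 0)
    return {
        "Accuracy": (tp + tn) / len(targets),
        "Specificity": tn / (tn + fp) if (tn + fp) else 0.0,  # true-negative rate
        "Sensitivity": tp / (tp + fn) if (tp + fn) else 0.0,  # recall on the positive class
        "F1": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0,
    }

targets     = [1, 1, 0, 0, 1, 0]
predictions = [1, 0, 0, 1, 1, 0]
print(binary_metrics(targets, predictions))
```

With these lists, tp=2, tn=2, fp=1, fn=1, so accuracy, specificity, sensitivity and F1 all come out to 2/3 — matching what `confusion_matrix(...).ravel()` plus the sklearn helpers would return.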
intuition1_classification_finetuning/config/methane_simulated_datamodule.py
ADDED
@@ -0,0 +1,116 @@
import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset

class MethaneSimulatedDataModule(NonGeoDataModule):
    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,  # Default test fold from your script
        num_folds: int = 5,
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds

        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _get_simulated_paths(self, paths):
        """Logic to rename files to I1/TOA format."""
        simulated_paths = []
        for path in paths:
            try:
                tokens = path.split('_')
                if len(tokens) >= 5:
                    simulated_path = f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
                    simulated_paths.append(simulated_path)
                else:
                    simulated_paths.append(path)
            except Exception:
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}")

        # 2. Filter folds (exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]

        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply path renaming logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/val split
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 5. Instantiate datasets
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, drop_last=True)

    # def on_after_batch_transfer(self, batch, dataloader_idx):
    #     # 1. Run TorchGeo default (expects 'image')
    #     batch = super().on_after_batch_transfer(batch, dataloader_idx)
    #
    #     # 2. Wrap into TerraMind format {'S2L2A': ...}
    #     if 'image' in batch:
    #         s2_data = batch['image']
    #         batch['image'] = {'S2L2A': s2_data}
    #
    #     return batch
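The `_get_simulated_paths` renaming in the datamodule keeps tokens 0, 3 and 4 of the underscore-split folder name and swaps the middle tokens for `toarefl`. That string logic can be sanity-checked in isolation (the folder names below are hypothetical, not real patch names):

```python
def to_simulated_path(path: str) -> str:
    # Mirrors MethaneSimulatedDataModule._get_simulated_paths for a single name:
    # keep tokens 0, 3, 4 and replace the middle tokens with "toarefl".
    tokens = path.split('_')
    if len(tokens) >= 5:
        return f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
    return path  # names that don't match the pattern pass through unchanged

print(to_simulated_path("scene_l2a_v1_r1234_c5678"))   # 5 tokens: renamed
print(to_simulated_path("scene_r1234_c5678"))          # too few tokens: unchanged
```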
intuition1_classification_finetuning/config/methane_simulated_dataset.py
ADDED
@@ -0,0 +1,66 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
import torch.nn.functional as F
import numpy as np

class MethaneSimulatedDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect paths
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                # Note: filenames here match the folder name
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load label
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load sCube (I1/TOA data)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()
            # Keep only the first 12 bands
            scube_image = scube_image[:12, :, :]

        # Convert to tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Convert labels to a binary class index (0 or 1)
        contains_methane = (label_tensor > 0).any().long()

        # Apply transformations
        if self.transform:
            # Albumentations expects [H, W, C]
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # Copy back to a contiguous [C, H, W] tensor (flips can produce negative strides)
            scube_tensor = torch.from_numpy(np.ascontiguousarray(transformed['image'].transpose(2, 0, 1))).float()

        return {
            'image': scube_tensor,      # <--- named 'image' for TorchGeo
            'label': contains_methane,  # <--- class index for CrossEntropy
            'sample': scube_path
        }
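Note that this config-side dataset returns a plain class index (`contains_methane`), whereas the script-side dataset variant one-hot encodes the same label with `F.one_hot(...)`. A tiny pure-Python sketch of that difference (values are illustrative):

```python
def to_one_hot(class_index: int, num_classes: int = 2):
    # Equivalent of F.one_hot(torch.tensor(class_index), num_classes).float()
    # for a single scalar label.
    return [1.0 if i == class_index else 0.0 for i in range(num_classes)]

# Index label (what nn.CrossEntropyLoss expects by default) vs one-hot target:
print(to_one_hot(1))  # methane present
print(to_one_hot(0))  # methane absent
```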
intuition1_classification_finetuning/config/train.yaml
ADDED
@@ -0,0 +1,66 @@
seed_everything: 42

trainer:
  accelerator: auto
  strategy: auto
  devices: 1
  max_epochs: 100
  default_root_dir: ./checkpoints_i1
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_top_k: 1
        filename: "best_model"
        save_last: true
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_factory: EncoderDecoderFactory
    loss: ce
    ignore_index: -1
    lr: 1.0e-5  # Top-level LR

    optimizer: AdamW
    optimizer_hparams:
      weight_decay: 0.05

    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      mode: min
      patience: 5

    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      backbone_merge_method: mean

      decoder: UperNetDecoder
      decoder_scale_modules: true
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.3

      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]

data:
  class_path: methane_simulated_datamodule.MethaneSimulatedDataModule
  init_args:
    data_root: # Place the data root here
    excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx  # Update this!
    batch_size: 8
    val_split: 0.2
    seed: 42
    test_fold: 4
    num_folds: 5
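The `test_fold`/`num_folds` pair in this config drives the hold-out split in the datamodule: every fold except `test_fold` goes into the train/val pool, and the pool is then split by `val_split`. The fold-selection step, in isolation:

```python
def train_pool_folds(num_folds: int, test_fold: int):
    # All folds except the held-out test fold, as in MethaneSimulatedDataModule.setup.
    return [f for f in range(1, num_folds + 1) if f != test_fold]

print(train_pool_folds(num_folds=5, test_fold=4))  # [1, 2, 3, 5]
```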
intuition1_classification_finetuning/script/methane_simulated_datamodule.py
ADDED
@@ -0,0 +1,72 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset
# from methane_classification_dataset import MethaneClassificationDataset

class MethaneSimulatedDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneSimulatedDataset.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )
intuition1_classification_finetuning/script/methane_simulated_dataset.py
ADDED
@@ -0,0 +1,81 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd

def min_max_normalize(data, new_min=0, new_max=1):
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Handle NaN, Inf values
    data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min

class MethaneSimulatedDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect paths for the mask and hyperspectral cube in selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()  # Shape: [1, 512, 512]

        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]

        # Keep only the first 12 bands for testing purposes; map the bands later on
        scube_image = scube_image[:12, :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [1, 512, 512]

        # Resize to [12, 224, 224] and [1, 224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0
        # normalized_tensor = min_max_normalize(scube_tensor)

        # Convert labels to a binary class index
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # Copy back to a contiguous [C, H, W] tensor (flips can produce negative strides)
            scube_tensor = torch.from_numpy(np.ascontiguousarray(transformed['image'].transpose(2, 0, 1))).float()

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'sample': scube_path}
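`min_max_normalize` above rescales a band to `[new_min, new_max]` while guarding against constant inputs. A pure-Python sketch of the same arithmetic on a small list (the values are made up):

```python
def min_max_scale(values, new_min=0.0, new_max=1.0):
    old_min, old_max = min(values), max(values)
    if old_max == old_min:  # constant band: avoid division by zero
        return [new_min for _ in values]
    scale = (new_max - new_min) / (old_max - old_min)
    return [(v - old_min) * scale + new_min for v in values]

print(min_max_scale([2.0, 4.0, 6.0]))  # [0.0, 0.5, 1.0]
print(min_max_scale([5.0, 5.0, 5.0]))  # constant input -> all new_min
```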
intuition1_classification_finetuning/script/train_simulated_I1.py
ADDED
@@ -0,0 +1,309 @@
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import csv
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Tuple, Any, Optional
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
import albumentations as A
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from sklearn.model_selection import train_test_split
|
| 19 |
+
from sklearn.metrics import (
|
| 20 |
+
accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
|
| 21 |
+
)
|
| 22 |
+
from rasterio.errors import NotGeoreferencedWarning
|
| 23 |
+
|
| 24 |
+
# --- CRITICAL IMPORTS ---
|
| 25 |
+
import terramind
|
| 26 |
+
from terratorch.tasks import ClassificationTask
|
| 27 |
+
|
| 28 |
+
# Local Imports
|
| 29 |
+
from methane_simulated_datamodule import MethaneSimulatedDataModule
|
| 30 |
+
|
| 31 |
+
# --- Configuration & Setup ---
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 36 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
| 37 |
+
)
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
|
| 41 |
+
warnings.simplefilter("ignore", NotGeoreferencedWarning)
|
| 42 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 43 |
+
|
| 44 |
+
def set_seed(seed: int = 42):
|
| 45 |
+
random.seed(seed)
|
| 46 |
+
np.random.seed(seed)
|
| 47 |
+
torch.manual_seed(seed)
|
| 48 |
+
if torch.cuda.is_available():
|
| 49 |
+
torch.cuda.manual_seed_all(seed)
|
| 50 |
+
|
| 51 |
+
def get_training_transforms() -> A.Compose:
|
| 52 |
+
return A.Compose([
|
| 53 |
+
A.ElasticTransform(p=0.25),
|
| 54 |
+
A.RandomRotate90(p=0.5),
|
| 55 |
+
A.Flip(p=0.5),
|
| 56 |
+
A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
|
| 57 |
+
])
|
| 58 |
+
|
| 59 |
+
# --- Path Utilities ---
|
| 60 |
+
|
| 61 |
+
def get_simulated_paths(paths: List[str]) -> List[str]:
|
| 62 |
+
"""
|
| 63 |
+
Modifies filenames to match the I1/TOA naming convention.
|
| 64 |
+
Converts 'ang2015..._S2_...' -> 'ang2015..._toarefl_...'
|
| 65 |
+
"""
|
| 66 |
+
simulated_paths = []
|
| 67 |
+
for path in paths:
|
| 68 |
+
try:
|
| 69 |
+
tokens = path.split('_')
|
| 70 |
+
# Logic: {ID}_toarefl_{Coord1}_{Coord2}
|
| 71 |
+
# Adjusts original filename tokens to target format
|
| 72 |
+
if len(tokens) >= 5:
|
| 73 |
+
simulated_path = f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
|
| 74 |
+
simulated_paths.append(simulated_path)
|
| 75 |
+
else:
|
| 76 |
+
simulated_paths.append(path)
|
| 77 |
+
except Exception as e:
|
| 78 |
+
logger.warning(f"Could not parse path {path}: {e}")
|
| 79 |
+
simulated_paths.append(path)
|
| 80 |
+
return simulated_paths
|
| 81 |
+
|
| 82 |
+
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
|
| 83 |
+
try:
|
| 84 |
+
df = pd.read_excel(excel_file)
|
| 85 |
+
df_filtered = df[df['Fold'].isin(folds)]
|
| 86 |
+
return df_filtered['Filename'].tolist()
|
| 87 |
+
except Exception as e:
|
| 88 |
+
logger.error(f"Error reading Excel file: {e}")
|
| 89 |
+
raise
|
| 90 |
+
|
| 91 |
+
# --- Helper Classes ---
|
| 92 |
+
|
| 93 |
+
class MetricTracker:
|
| 94 |
+
def __init__(self):
|
| 95 |
+
self.reset()
|
| 96 |
+
|
| 97 |
+
def reset(self):
|
| 98 |
+
self.all_targets = []
|
| 99 |
+
self.all_predictions = []
|
| 100 |
+
self.total_loss = 0.0
|
| 101 |
+
self.steps = 0
|
| 102 |
+
|
| 103 |
+
def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
|
| 104 |
+
self.total_loss += loss
|
| 105 |
+
self.steps += 1
|
| 106 |
+
self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
|
| 107 |
+
self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())
|
| 108 |
+
|
| 109 |
+
def compute(self) -> Dict[str, float]:
|
| 110 |
+
if not self.all_targets:
|
| 111 |
+
return {}
|
| 112 |
+
|
| 113 |
+
tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"Loss": self.total_loss / max(self.steps, 1),
|
| 117 |
+
"Accuracy": accuracy_score(self.all_targets, self.all_predictions),
|
| 118 |
+
"Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
|
| 119 |
+
"Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
|
| 120 |
+
"F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
|
| 121 |
+
"MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
|
| 122 |
+
"TP": int(tp), "TN": int(tn), "FP": int(fp), "FN": int(fn)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
class TrainerI1:
|
| 126 |
+
def __init__(self, args: argparse.Namespace):
|
| 127 |
+
self.args = args
|
| 128 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 129 |
+
self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
|
| 130 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 131 |
+
|
| 132 |
+
self.model = self._init_model()
|
| 133 |
+
self.optimizer, self.scheduler = self._init_optimizer()
|
| 134 |
+
self.criterion = self.task.criterion
|
| 135 |
+
self.best_val_loss = float('inf')
|
| 136 |
+
|
| 137 |
+
logger.info(f"Trainer initialized on device: {self.device}")
|
| 138 |
+
|
| 139 |
+
def _init_model(self) -> nn.Module:
|
| 140 |
+
model_args = dict(
|
| 141 |
+
backbone="terramind_v1_base",
|
| 142 |
+
backbone_pretrained=True,
|
| 143 |
+
backbone_modalities=["S2L2A"],
|
| 144 |
+
backbone_merge_method="mean",
|
| 145 |
+
decoder="UperNetDecoder",
|
| 146 |
+
decoder_scale_modules=True,
|
| 147 |
+
decoder_channels=256,
|
| 148 |
+
num_classes=2,
|
| 149 |
+
head_dropout=0.3,
|
| 150 |
+
necks=[
|
| 151 |
+
{"name": "ReshapeTokensToImage", "remove_cls_token": False},
|
| 152 |
+
{"name": "SelectIndices", "indices": [2, 5, 8, 11]},
|
| 153 |
+
{"name": "LearnedInterpolateToPyramidal"},
|
| 154 |
+
],
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self.task = ClassificationTask(
|
| 158 |
+
model_args=model_args,
|
| 159 |
+
model_factory="EncoderDecoderFactory",
|
| 160 |
+
loss="ce",
|
| 161 |
+
lr=self.args.lr,
|
| 162 |
+
ignore_index=-1,
|
| 163 |
+
optimizer="AdamW",
|
| 164 |
+
optimizer_hparams={"weight_decay": self.args.weight_decay},
|
| 165 |
+
)
|
| 166 |
+
self.task.configure_models()
|
| 167 |
+
self.task.configure_losses()
|
| 168 |
+
return self.task.model.to(self.device)
|
| 169 |
+
|
| 170 |
+
def _init_optimizer(self):
|
| 171 |
+
optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
|
| 172 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
|
| 173 |
+
return optimizer, scheduler
|
| 174 |
+
|
| 175 |
+
def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
|
| 176 |
+
is_train = stage == "train"
|
| 177 |
+
self.model.train() if is_train else self.model.eval()
|
| 178 |
+
tracker = MetricTracker()
|
| 179 |
+
|
| 180 |
+
with torch.set_grad_enabled(is_train):
|
| 181 |
+
pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
|
| 182 |
+
for batch in pbar:
|
| 183 |
+
inputs = batch['S2L2A'].to(self.device)
|
| 184 |
+
targets = batch['label'].to(self.device)
|
| 185 |
+
|
| 186 |
+
outputs = self.model(x={"S2L2A": inputs})
|
| 187 |
+
probabilities = torch.softmax(outputs.output, dim=1)
|
| 188 |
+
loss = self.criterion(probabilities, targets)
|
| 189 |
+
|
| 190 |
+
if is_train:
|
| 191 |
+
self.optimizer.zero_grad()
|
| 192 |
+
loss.backward()
|
| 193 |
+
self.optimizer.step()
|
| 194 |
+
|
| 195 |
+
tracker.update(loss.item(), targets, probabilities)
|
| 196 |
+
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
| 197 |
+
|
| 198 |
+
return tracker.compute()
|
| 199 |
+
|
| 200 |
+
def fit(self, train_loader: DataLoader, val_loader: DataLoader):
|
| 201 |
+
logger.info(f"Starting training for {self.args.epochs} epochs...")
|
| 202 |
+
start_time = time.time()
|
| 203 |
+
|
| 204 |
+
# Initialize CSV logging
|
| 205 |
+
csv_path = self.save_dir / 'train_val_metrics.csv'
|
| 206 |
+
with open(csv_path, 'w', newline='') as f:
|
| 207 |
+
writer = csv.writer(f)
|
| 208 |
+
writer.writerow([
|
| 209 |
+
'Epoch', 'Train_Loss', 'Train_F1', 'Train_Acc',
|
| 210 |
+
'Val_Loss', 'Val_F1', 'Val_Acc', 'Val_Spec', 'Val_Sens'
|
| 211 |
+
])
|
| 212 |
+
|
| 213 |
+
for epoch in range(1, self.args.epochs + 1):
|
| 214 |
+
logger.info(f"Epoch {epoch}/{self.args.epochs}")
|
| 215 |
+
|
| 216 |
+
train_metrics = self.run_epoch(train_loader, stage="train")
|
| 217 |
+
val_metrics = self.run_epoch(val_loader, stage="validate")
|
| 218 |
+
|
| 219 |
+
self.scheduler.step(val_metrics['Loss'])
|
| 220 |
+
|
| 221 |
+
# Log to CSV
|
| 222 |
+
with open(csv_path, 'a', newline='') as f:
|
| 223 |
+
writer = csv.writer(f)
|
| 224 |
+
writer.writerow([
|
| 225 |
+
epoch,
|
| 226 |
+
train_metrics.get('Loss'), train_metrics.get('F1'), train_metrics.get('Accuracy'),
|
| 227 |
+
val_metrics.get('Loss'), val_metrics.get('F1'), val_metrics.get('Accuracy'),
|
| 228 |
+
val_metrics.get('Specificity'), val_metrics.get('Sensitivity')
|
| 229 |
+
])
|
| 230 |
+
|
| 231 |
+
logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")
|
| 232 |
+
|
| 233 |
+
# Save Best Model
|
| 234 |
+
if val_metrics['Loss'] < self.best_val_loss:
|
| 235 |
+
self.best_val_loss = val_metrics['Loss']
|
| 236 |
+
torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
|
| 237 |
+
logger.info(f"--> New best model saved")
|
| 238 |
+
|
| 239 |
+
# Save Final Model
|
| 240 |
+
torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
|
| 241 |
+
logger.info(f"Training finished in {time.time() - start_time:.2f}s")
|
| 242 |
+
|
| 243 |
+
# --- Data Utilities ---
|
| 244 |
+
|
| 245 |
+
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
|
| 246 |
+
# 1. Determine Folds
|
| 247 |
+
all_folds = list(range(1, args.num_folds + 1))
|
| 248 |
+
train_pool_folds = [f for f in all_folds if f != args.test_fold]
|
| 249 |
+
|
| 250 |
+
# 2. Get Paths & Convert to TOA/I1 format
|
| 251 |
+
# Note: Using get_simulated_paths to transform names as done in the notebook
|
| 252 |
+
paths = get_paths_for_fold(args.excel_file, train_pool_folds)
|
| 253 |
+
paths = get_simulated_paths(paths)
|
| 254 |
+
|
| 255 |
+
# 3. Train/Val Split (80/20)
|
| 256 |
+
train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
|
| 257 |
+
|
| 258 |
+
logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")
|
| 259 |
+
|
| 260 |
+
# 4. Initialize DataModule
|
| 261 |
+
datamodule = MethaneSimulatedDataModule(
|
| 262 |
+
data_root=args.root_dir,
|
| 263 |
+
excel_file=args.excel_file,
|
| 264 |
+
batch_size=args.batch_size,
|
| 265 |
+
paths=paths, # Initial dummy
|
| 266 |
+
train_transform=get_training_transforms(),
|
| 267 |
+
val_transform=None,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# 5. Create Loaders
|
| 271 |
+
datamodule.paths = train_paths
|
| 272 |
+
datamodule.setup(stage="fit")
|
| 273 |
+
train_loader = datamodule.train_dataloader()
|
| 274 |
+
|
| 275 |
+
datamodule.paths = val_paths
|
| 276 |
+
datamodule.setup(stage="validate")
|
| 277 |
+
val_loader = datamodule.val_dataloader()
|
| 278 |
+
|
| 279 |
+
return train_loader, val_loader
|
| 280 |
+
|
| 281 |
+
# --- Main Execution ---
|
| 282 |
+
|
| 283 |
+
def parse_args():
|
| 284 |
+
parser = argparse.ArgumentParser(description="Methane I1 (TOA Refl) Training")
|
| 285 |
+
|
| 286 |
+
# Paths
|
| 287 |
+
parser.add_argument('--root_dir', type=str, required=True, help='Root directory for I1/TOA data')
|
| 288 |
+
parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel')
|
| 289 |
+
parser.add_argument('--save_dir', type=str, default='./checkpoints_i1', help='Output directory')
|
| 290 |
+
|
| 291 |
+
# Hyperparameters
|
| 292 |
+
parser.add_argument('--epochs', type=int, default=100)
|
| 293 |
+
parser.add_argument('--batch_size', type=int, default=1)
|
| 294 |
+
parser.add_argument('--lr', type=float, default=1e-5)
|
| 295 |
+
parser.add_argument('--weight_decay', type=float, default=0.05)
|
| 296 |
+
parser.add_argument('--num_folds', type=int, default=5)
|
| 297 |
+
parser.add_argument('--test_fold', type=int, default=4, help='Fold ID to hold out')
|
| 298 |
+
parser.add_argument('--seed', type=int, default=42)
|
| 299 |
+
|
| 300 |
+
return parser.parse_args()
|
| 301 |
+
|
| 302 |
+
if __name__ == "__main__":
|
| 303 |
+
args = parse_args()
|
| 304 |
+
set_seed(args.seed)
|
| 305 |
+
|
| 306 |
+
train_loader, val_loader = get_data_loaders(args)
|
| 307 |
+
|
| 308 |
+
trainer = TrainerI1(args)
|
| 309 |
+
trainer.fit(train_loader, val_loader)
|
sentinel2_classification_finetuning/config/methane_simulated_datamodule.py
ADDED
@@ -0,0 +1,119 @@
import pandas as pd
import albumentations as A
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset

class MethaneSimulatedDataModule(NonGeoDataModule):
    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        test_fold: int = 4,
        num_folds: int = 5,
        sim_tag: str = "toarefl",  # <--- New arg for 'toarefl'/'boarefl'
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        self.test_fold = test_fold
        self.num_folds = num_folds
        self.sim_tag = sim_tag

        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def _get_simulated_paths(self, paths):
        """Internal logic to rename files based on sim_tag"""
        simulated_paths = []
        for path in paths:
            try:
                tokens = path.split('_')
                if len(tokens) >= 5:
                    # Logic: {ID}_{tag}_{Coord1}_{Coord2}
                    simulated_path = f"{tokens[0]}_{self.sim_tag}_{tokens[3]}_{tokens[4]}"
                    simulated_paths.append(simulated_path)
                else:
                    simulated_paths.append(path)
            except Exception:
                simulated_paths.append(path)
        return simulated_paths

    def setup(self, stage: str = None):
        # 1. Read Excel
        try:
            df = pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load excel: {e}")

        # 2. Filter Folds (Exclude test_fold)
        all_folds = list(range(1, self.num_folds + 1))
        train_pool_folds = [f for f in all_folds if f != self.test_fold]

        df_filtered = df[df['Fold'].isin(train_pool_folds)]
        raw_paths = df_filtered['Filename'].tolist()

        # 3. Apply Path Renaming Logic
        paths = self._get_simulated_paths(raw_paths)

        # 4. Train/Val Split
        self.train_paths, self.val_paths = train_test_split(
            paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 5. Instantiate Datasets
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, drop_last=True)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # 1. Run TorchGeo default (expects 'image')
        batch = super().on_after_batch_transfer(batch, dataloader_idx)

        # 2. Wrap into TerraMind format {'S2L2A': ...}
        if 'image' in batch:
            s2_data = batch['image']
            batch['image'] = {'S2L2A': s2_data}

        return batch
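The `_get_simulated_paths` renaming above keeps tokens 0, 3, and 4 of a 5-token filename and swaps `sim_tag` into the second slot. A standalone sketch of that transform, using a hypothetical filename (the real IDs come from the summary Excel):

```python
def rename(path: str, sim_tag: str = "toarefl") -> str:
    """Rebuild '{ID}_{tag}_{Coord1}_{Coord2}' from a 5-token filename,
    falling back to the original name when the pattern does not match
    (sketch of the datamodule's renaming logic)."""
    tokens = path.split('_')
    if len(tokens) >= 5:
        return f"{tokens[0]}_{sim_tag}_{tokens[3]}_{tokens[4]}"
    return path

# Hypothetical example name following the 5-token convention:
print(rename("ang20150420t182808_r2_S2_0012_0034"))
# -> "ang20150420t182808_toarefl_0012_0034"
print(rename("short_name"))  # unmatched pattern is passed through unchanged
```

Note that tokens 1 and 2 of the original name are dropped entirely, so two source files that differ only in those tokens would collide after renaming.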
sentinel2_classification_finetuning/config/methane_simulated_dataset.py
ADDED
@@ -0,0 +1,70 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
import torch.nn.functional as F
import numpy as np

class MethaneSimulatedDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect paths
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                # Construct paths based on folder name
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load label
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()

        # Load sCube (Try explicit ENVI driver first for .dat files)
        try:
            with rasterio.open(scube_path, driver='ENVI') as scube_src:
                scube_image = scube_src.read()
                scube_image = scube_image[:12, :, :]  # Keep first 12 bands
        except Exception:
            # Fallback if driver auto-detection is needed
            with rasterio.open(scube_path) as scube_src:
                scube_image = scube_src.read()
                scube_image = scube_image[:12, :, :]

        # Convert to Tensors
        scube_tensor = torch.from_numpy(scube_image).float()
        label_tensor = torch.from_numpy(label_image).float()

        # Resize
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)

        # Convert labels to binary index (0 or 1)
        contains_methane = (label_tensor > 0).any().long()

        # Apply transformations
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            # Albumentations returns a numpy array; convert back to a [C, H, W] tensor
            scube_tensor = torch.from_numpy(transformed['image']).permute(2, 0, 1).float()

        return {
            'image': scube_tensor,       # <--- 'image' for TorchGeo
            'label': contains_methane,   # <--- Index for CE Loss
            'sample': scube_path
        }
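The dataset above collapses the pixel-level mask into a patch-level label with `(label_tensor > 0).any()`: a patch counts as "contains methane" if even a single mask pixel is positive. The same decision in plain Python on a toy 2x2 mask (illustrative, not the project code):

```python
def patch_label(mask) -> int:
    """Return 1 if any pixel of the (clipped) mask is positive, else 0 --
    the patch-level 'contains methane' flag the dataset derives."""
    return int(any(v > 0 for row in mask for v in row))

assert patch_label([[0, 0], [0, 1]]) == 1   # one plume pixel -> positive patch
assert patch_label([[0, 0], [0, 0]]) == 0   # empty mask -> negative patch
```

Because a single positive pixel flips the whole patch, this labeling is sensitive to mask noise; the `clip(0, 1)` beforehand only normalizes values, it does not filter spurious pixels.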
sentinel2_classification_finetuning/config/train.yaml
ADDED
@@ -0,0 +1,67 @@
seed_everything: 42

trainer:
  accelerator: auto
  strategy: auto
  devices: 1
  max_epochs: 100
  default_root_dir: ./checkpoints_s2_simulated
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val/loss
        mode: min
        save_top_k: 1
        filename: "best_model"
        save_last: true
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  class_path: terratorch.tasks.ClassificationTask
  init_args:
    model_factory: EncoderDecoderFactory
    loss: ce
    ignore_index: -1
    lr: 1.0e-5  # Top-level LR

    optimizer: AdamW
    optimizer_hparams:
      weight_decay: 0.05

    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      mode: min
      patience: 5

    model_args:
      backbone: terramind_v1_base
      backbone_pretrained: true
      backbone_modalities:
        - S2L2A
      backbone_merge_method: mean

      decoder: UperNetDecoder
      decoder_scale_modules: true
      decoder_channels: 256
      num_classes: 2
      head_dropout: 0.3

      necks:
        - name: ReshapeTokensToImage
          remove_cls_token: false
        - name: SelectIndices
          indices: [2, 5, 8, 11]

data:
  class_path: methane_simulated_datamodule.MethaneSimulatedDataModule
  init_args:
    data_root: /path/to/data_root  # UPDATE THIS
    excel_file: ../../Methane_benchmark_patches_summary_v3.xlsx
    batch_size: 8
    val_split: 0.2
    seed: 42
    test_fold: 4
    num_folds: 5
    sim_tag: toarefl  # Change to 'boarefl' if needed
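The config above schedules the learning rate with `ReduceLROnPlateau` (`mode: min`, `patience: 5`): the LR is cut once validation loss fails to improve for more than 5 consecutive epochs. A toy re-implementation of that trigger, purely illustrative (the reduction `factor` of 0.1 is PyTorch's default, not set in this config):

```python
def lr_schedule(val_losses, lr=1e-5, patience=5, factor=0.1):
    """Track the best loss so far; once more than `patience` consecutive
    epochs fail to improve it, multiply the LR by `factor` (a sketch of
    ReduceLROnPlateau's behavior, not the torch implementation)."""
    best, bad_epochs, lrs = float("inf"), 0, []
    for loss in val_losses:
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor
                bad_epochs = 0
        lrs.append(lr)
    return lrs

# Seven epochs stuck at the same loss: six non-improving epochs in a row,
# so the LR drops on the seventh epoch.
lr_schedule([1.0] * 7)
```

With `max_epochs: 100` and `lr: 1.0e-5`, a flat validation loss would therefore step the LR down roughly every six epochs.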
sentinel2_classification_finetuning/script/inference_s2_simulated.py
ADDED
@@ -0,0 +1,241 @@
import argparse
import logging
import csv
import random
import warnings
import time
import os
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from rasterio.errors import NotGeoreferencedWarning

# --- CRITICAL IMPORTS ---
import terramind
from terratorch.tasks import ClassificationTask

# Local Imports
from methane_simulated_datamodule import MethaneSimulatedDataModule

# --- Configuration & Setup ---

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_inference_transforms() -> Optional[A.Compose]:
    return None

# --- Path Utilities (Crucial for Simulated Data) ---

def get_simulated_paths(paths: List[str]) -> List[str]:
    """
    Modifies filenames to match the simulated dataset naming convention:
    '{ID}_toarefl_{Coord1}_{Coord2}' rebuilt from the original tokens.
    """
    simulated_paths = []
    for path in paths:
        try:
            tokens = path.split('_')
            # Reconstruct filename based on notebook logic
            if len(tokens) >= 5:
                simulated_path = f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
                simulated_paths.append(simulated_path)
            else:
                simulated_paths.append(path)
        except Exception as e:
            logger.warning(f"Could not parse path {path}: {e}")
            simulated_paths.append(path)
    return simulated_paths

# --- Inference Class ---

class SimulatedInference:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initializing Inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )

        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)

        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates predictions and matches them with provided sample identifiers.
        """
        results = {}

        logger.info(f"Starting inference on {len(sample_names)} samples...")

        # Iterator to match predictions with original filenames
        name_iter = iter(sample_names)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)

                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()

                # Assign to Sample Names
                for pred in batch_preds:
                    try:
                        sample_name = next(name_iter)
                        results[sample_name] = int(pred)
                    except StopIteration:
                        logger.error("More predictions than sample names! Check sync.")
                        break

        if len(results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(results)}.")

        # Save CSV
        self._save_results(results)

    def _save_results(self, results: Dict[str, int]):
        csv_path = self.output_dir / "simulated_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_ID', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")

# --- Data Loading ---

def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    # 1. Read Excel to get base filenames
    try:
        df = pd.read_excel(args.excel_file)
        # If specific folds are requested, filter them
        if args.folds:
            folds_to_use = [int(f) for f in args.folds.split(',')]
            df = df[df['Fold'].isin(folds_to_use)]
            logger.info(f"Filtered to folds: {folds_to_use}")

        raw_paths = df['Filename'].tolist()
        logger.info(f"Loaded {len(raw_paths)} paths from Excel.")
    except Exception as e:
        logger.error(f"Error reading Excel: {e}")
        raise

    # 2. Transform paths to Simulated format
    simulated_paths = get_simulated_paths(raw_paths)

    # 3. Initialize DataModule
datamodule = MethaneSimulatedDataModule(
|
| 200 |
+
data_root=args.root_dir,
|
| 201 |
+
excel_file=args.excel_file,
|
| 202 |
+
batch_size=args.batch_size,
|
| 203 |
+
paths=simulated_paths,
|
| 204 |
+
train_transform=None,
|
| 205 |
+
val_transform=get_inference_transforms(),
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# 4. Setup
|
| 209 |
+
datamodule.paths = simulated_paths
|
| 210 |
+
datamodule.setup(stage="test")
|
| 211 |
+
|
| 212 |
+
# Try getting test_dataloader, else train/val
|
| 213 |
+
loader = datamodule.test_dataloader() if hasattr(datamodule, 'test_dataloader') else datamodule.train_dataloader()
|
| 214 |
+
|
| 215 |
+
return loader, simulated_paths
|
| 216 |
+
|
| 217 |
+
# --- Main Execution ---
|
| 218 |
+
|
| 219 |
+
def parse_args():
|
| 220 |
+
parser = argparse.ArgumentParser(description="Methane Simulated S2 Inference")
|
| 221 |
+
|
| 222 |
+
parser.add_argument('--root_dir', type=str, required=True, help='Root directory for simulated data')
|
| 223 |
+
parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel file')
|
| 224 |
+
parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')
|
| 225 |
+
parser.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
|
| 226 |
+
parser.add_argument('--folds', type=str, default=None, help='Comma-separated list of folds to infer on (e.g., "4" or "1,2"). If None, uses all.')
|
| 227 |
+
parser.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
|
| 228 |
+
parser.add_argument('--seed', type=int, default=42)
|
| 229 |
+
|
| 230 |
+
return parser.parse_args()
|
| 231 |
+
|
| 232 |
+
if __name__ == "__main__":
|
| 233 |
+
args = parse_args()
|
| 234 |
+
set_seed(args.seed)
|
| 235 |
+
|
| 236 |
+
# 1. Prepare Data & Names
|
| 237 |
+
dataloader, sample_names = get_dataloader_and_names(args)
|
| 238 |
+
|
| 239 |
+
# 2. Run Inference
|
| 240 |
+
engine = SimulatedInference(args)
|
| 241 |
+
engine.run_inference(dataloader, sample_names)
|
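The matching loop in `run_inference` pairs predictions with filenames purely by iteration order, tolerating (and logging) length mismatches. That contract can be sketched standalone with `zip_longest`; the names and predictions below are made up for illustration:

```python
from itertools import zip_longest

def pair_predictions(sample_names, batch_preds_per_batch):
    """Pair a batched prediction stream with sample names by order.

    Mirrors the iterator-based matching in run_inference: predictions
    arrive batch by batch; names come from the Excel row order. On a
    length mismatch, pairing stops early (the caller logs a warning).
    """
    results = {}
    flat_preds = [p for batch in batch_preds_per_batch for p in batch]
    for name, pred in zip_longest(sample_names, flat_preds):
        if name is None or pred is None:
            break  # one side ran out; stop rather than invent pairs
        results[name] = int(pred)
    return results

# Hypothetical names and two batches of predictions (batch_size=2)
names = ["scene_a", "scene_b", "scene_c"]
preds = [[1, 0], [1]]
print(pair_predictions(names, preds))  # {'scene_a': 1, 'scene_b': 0, 'scene_c': 1}
```

Note that this ordering contract is exactly what `drop_last=True` on a test dataloader would silently break, which is why the script checks `len(results) != len(sample_names)` afterwards.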
sentinel2_classification_finetuning/script/methane_simulated_datamodule.py
ADDED
|
@@ -0,0 +1,72 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_simulated_dataset import MethaneSimulatedDataset
# from methane_classification_dataset import MethaneClassificationDataset

class MethaneSimulatedDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneSimulatedDataset.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneSimulatedDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        if stage in ("fit", "train"):
            self.train_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneSimulatedDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )

    def test_dataloader(self):
        # drop_last=False: at inference every sample must be kept, otherwise
        # predictions fall out of sync with the sample-name list
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )
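The `setup` method above dispatches on the Lightning-style stage string, so a single `"fit"` call builds both the train and validation datasets. A torchgeo-free stand-in makes the dispatch table explicit (the class name is invented for this sketch):

```python
class StageDispatchSketch:
    """Stand-in (no torchgeo) showing which datasets setup() builds per stage."""

    def __init__(self):
        self.built = []

    def setup(self, stage=None):
        # Same membership tests as MethaneSimulatedDataModule.setup
        if stage in ("fit", "train"):
            self.built.append("train")
        if stage in ("fit", "validate", "val"):
            self.built.append("val")
        if stage in ("test", "predict"):
            self.built.append("test")

dm = StageDispatchSketch()
dm.setup("fit")
print(dm.built)  # ['train', 'val'] -- "fit" builds both, as Lightning expects
```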
sentinel2_classification_finetuning/script/methane_simulated_dataset.py
ADDED
|
@@ -0,0 +1,81 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd

def min_max_normalize(data, new_min=0, new_max=1):
    data = np.array(data, dtype=np.float32)  # Convert to NumPy array

    # Handle NaN / Inf values
    data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))

    old_min, old_max = np.min(data), np.max(data)

    if old_max == old_min:  # Prevent division by zero
        return np.full_like(data, new_min, dtype=np.float32)  # Uniform array

    return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min

class MethaneSimulatedDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []

        # Collect the mask and HSI-cube paths in the selected folders
        for folder_name in paths:
            subdir_path = os.path.join(root_dir, folder_name)
            if os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, folder_name + '_mask.tif')
                scube_path = os.path.join(subdir_path, folder_name + '_hsi.dat')

                if os.path.exists(label_path) and os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the label image (single band)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read()  # Shape: [1, 512, 512]

        # Load the sCube image (multi-band)
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]

        # Keep only the first 12 bands for testing purposes;
        # map the bands properly later on
        scube_image = scube_image[:12, :, :]

        # Convert to PyTorch tensors
        scube_tensor = torch.from_numpy(scube_image).float()  # Shape: [12, 512, 512]
        label_tensor = torch.from_numpy(label_image).float()  # Shape: [1, 512, 512]

        # Resize to [12, 224, 224] and [1, 224, 224] respectively
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label_tensor = F.interpolate(label_tensor.unsqueeze(0), size=(224, 224), mode='nearest').squeeze(0)

        label_tensor = label_tensor.clip(0, 1)  # Clip values to [0, 1]
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0
        # normalized_tensor = min_max_normalize(scube_tensor)

        # Binary patch label: does the patch contain any methane pixel?
        contains_methane = (label_tensor > 0).any().long()

        # Convert to one-hot encoding
        one_hot_label = F.one_hot(contains_methane, num_classes=2).float()

        # Apply transformations (if any); albumentations returns a NumPy
        # array, so convert back to a tensor in [C, H, W] layout
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = torch.from_numpy(transformed['image']).permute(2, 0, 1)

        return {'S2L2A': scube_tensor, 'label': one_hot_label, 'sample': scube_path}
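The patch-level label in `__getitem__` reduces the segmentation mask to a binary "any methane pixel?" decision and one-hot encodes it. A minimal pure-Python sketch of that reduction (no torch, mask given as nested lists):

```python
def patch_label(mask_rows):
    """Binary patch label as in __getitem__: class 1 if any mask pixel > 0,
    returned one-hot encoded as [p(background), p(methane)]."""
    contains_methane = int(any(v > 0 for row in mask_rows for v in row))
    one_hot = [0.0, 0.0]
    one_hot[contains_methane] = 1.0
    return one_hot

print(patch_label([[0, 0], [0, 3]]))  # [0.0, 1.0] -- plume present
print(patch_label([[0, 0], [0, 0]]))  # [1.0, 0.0] -- background only
```

Because a single positive pixel flips the whole patch to class 1, resizing the mask with `mode='nearest'` before this check matters: bilinear interpolation could smear fractional values across the patch, while nearest-neighbour keeps the original binary pixels.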
sentinel2_classification_finetuning/script/train_simulated_s2.py
ADDED
|
@@ -0,0 +1,314 @@
import argparse
import logging
import csv
import random
import warnings
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
)
from rasterio.errors import NotGeoreferencedWarning

# --- CRITICAL IMPORTS ---
import terramind
from terratorch.tasks import ClassificationTask

# Local imports
from methane_simulated_datamodule import MethaneSimulatedDataModule

# --- Configuration & Setup ---

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_training_transforms() -> A.Compose:
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])

# --- Path Utilities ---

def get_simulated_paths(paths: List[str], tag: str = "toarefl") -> List[str]:
    """
    Modifies filenames to match the I1/TOA naming convention.
    Converts 'ang2015..._S2_...' -> 'ang2015..._{tag}_...'
    """
    simulated_paths = []
    for path in paths:
        try:
            tokens = path.split('_')
            # Target format: {ID}_{tag}_{Coord1}_{Coord2}
            if len(tokens) >= 5:
                simulated_path = f"{tokens[0]}_{tag}_{tokens[3]}_{tokens[4]}"
                simulated_paths.append(simulated_path)
            else:
                simulated_paths.append(path)
        except Exception as e:
            logger.warning(f"Could not parse path {path}: {e}")
            simulated_paths.append(path)
    return simulated_paths

def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    try:
        df = pd.read_excel(excel_file)
        df_filtered = df[df['Fold'].isin(folds)]
        return df_filtered['Filename'].tolist()
    except Exception as e:
        logger.error(f"Error reading Excel file: {e}")
        raise

# --- Helper Classes ---

class MetricTracker:
    def __init__(self):
        self.reset()

    def reset(self):
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        self.total_loss += loss
        self.steps += 1
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().numpy())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().numpy())

    def compute(self) -> Dict[str, float]:
        if not self.all_targets:
            return {}

        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
            "TP": int(tp), "TN": int(tn), "FP": int(fp), "FN": int(fn)
        }

class TrainerI1:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()  # also sets self.task
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        is_train = stage == "train"
        self.model.train() if is_train else self.model.eval()
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs})
                logits = outputs.output
                probabilities = torch.softmax(logits, dim=1)
                # Cross-entropy expects raw logits, not softmax probabilities
                loss = self.criterion(logits, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                tracker.update(loss.item(), targets, probabilities)
                pbar.set_postfix(loss=f"{loss.item():.4f}")

        return tracker.compute()

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        # Initialize CSV logging
        csv_path = self.save_dir / 'train_val_metrics.csv'
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'Epoch', 'Train_Loss', 'Train_F1', 'Train_Acc',
                'Val_Loss', 'Val_F1', 'Val_Acc', 'Val_Spec', 'Val_Sens'
            ])

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            self.scheduler.step(val_metrics['Loss'])

            # Log to CSV
            with open(csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    train_metrics.get('Loss'), train_metrics.get('F1'), train_metrics.get('Accuracy'),
                    val_metrics.get('Loss'), val_metrics.get('F1'), val_metrics.get('Accuracy'),
                    val_metrics.get('Specificity'), val_metrics.get('Sensitivity')
                ])

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            # Save best model
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info("--> New best model saved")

        # Save final model
        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")

# --- Data Utilities ---

def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    # 1. Determine folds
    all_folds = list(range(1, args.num_folds + 1))
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # 2. Get paths
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)

    # 3. Apply tag (dynamic tagging)
    paths = get_simulated_paths(paths, tag=args.sim_tag)

    # 4. Train/val split (80/20)
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # 5. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=paths,  # placeholder; replaced per split below
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # 6. Create loaders
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader

# --- Main Execution ---

def parse_args():
    parser = argparse.ArgumentParser(description="Methane I1 (Simulated) Training")

    # Paths
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for I1/TOA/BOA data')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel')
    parser.add_argument('--save_dir', type=str, default='./checkpoints_i1', help='Output directory')

    # Simulation tag config
    parser.add_argument('--sim_tag', type=str, default='toarefl',
                        help='String identifier in filename (e.g. "toarefl" or "boarefl")')

    # Hyperparameters
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size (must be >1 for BatchNorm)')
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=4, help='Fold ID to hold out')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)

    train_loader, val_loader = get_data_loaders(args)

    trainer = TrainerI1(args)
    trainer.fit(train_loader, val_loader)
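The filename surgery in `get_simulated_paths` replaces the second and third underscore-separated tokens with a single simulation tag while keeping the scene ID and the two coordinate tokens. A single-filename mirror of that logic (the example filename below is hypothetical, following the `{ID}_{sensor}_{res}_{Coord1}_{Coord2}` pattern implied by the docstring):

```python
def to_simulated_name(path: str, tag: str = "toarefl") -> str:
    """Mirror of get_simulated_paths for one filename (no logging).

    Keeps tokens[0] (scene ID) and tokens[3:5] (coordinates), replacing
    everything in between with the simulation tag.
    """
    tokens = path.split('_')
    if len(tokens) >= 5:
        return f"{tokens[0]}_{tag}_{tokens[3]}_{tokens[4]}"
    return path  # too few tokens: returned unchanged, as in the original

# Hypothetical AVIRIS-NG-style filename
print(to_simulated_name("ang20150420_S2_10m_r1024_c2048"))
# ang20150420_toarefl_r1024_c2048
```

Note that filenames with more than five tokens lose their trailing tokens, matching the original function's behaviour.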
urban_inference/methane_urban_datamodule.py
ADDED
|
@@ -0,0 +1,71 @@
import os
import rasterio
import torch
from torchgeo.datasets import NonGeoDataset
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_urban_dataset import MethaneUrbanDataset

class MethaneUrbanDataModule(NonGeoDataModule):
    """
    A DataModule for handling MethaneUrbanDataset.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        paths: list,
        batch_size: int = 8,
        num_workers: int = 0,
        train_transform: callable = None,
        val_transform: callable = None,
        test_transform: callable = None,
        **kwargs
    ):
        super().__init__(MethaneUrbanDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.paths = paths
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: str = None):
        if stage in ("fit", "train"):
            self.train_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.train_transform,
            )
        if stage in ("fit", "validate", "val"):
            self.val_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.val_transform,
            )
        if stage in ("test", "predict"):
            self.test_dataset = MethaneUrbanDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.paths,
                transform=self.test_transform,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True
        )

    def test_dataloader(self):
        # drop_last=False: keep every sample at inference time
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False
        )
urban_inference/methane_urban_dataset.py
ADDED
|
@@ -0,0 +1,82 @@
| 1 |
+
import os
|
| 2 |
+
import rasterio
|
| 3 |
+
import torch
|
| 4 |
+
from torchgeo.datasets import NonGeoDataset
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
def min_max_normalize(data, new_min=0, new_max=1):
|
| 11 |
+
data = np.array(data, dtype=np.float32) # Convert to NumPy array
|
| 12 |
+
|
| 13 |
+
# Handle NaN, Inf values
|
| 14 |
+
data = np.nan_to_num(data, nan=np.nanmin(data), posinf=np.max(data), neginf=np.min(data))
|
| 15 |
+
|
| 16 |
+
old_min, old_max = np.min(data), np.max(data)
|
| 17 |
+
|
| 18 |
+
if old_max == old_min: # Prevent division by zero
|
| 19 |
+
return np.full_like(data, new_min, dtype=np.float32) # Uniform array
|
| 20 |
+
|
| 21 |
+
return (data - old_min) / (old_max - old_min + 1e-10) * (new_max - new_min) + new_min
|
| 22 |
+
|

class MethaneUrbanDataset(NonGeoDataset):
    def __init__(self, root_dir, excel_file, paths, transform=None, mean=None, std=None):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.data_paths = []
        self.mean = mean if mean else [0.485] * 12  # Default mean if not provided
        self.std = std if std else [0.229] * 12  # Default std if not provided

        # Collect data paths in the selected folders; label and cube both
        # point at the 'hsi' product here
        for folder_name in paths:
            subdir_path = next(
                (os.path.join(root_dir, d) for d in os.listdir(root_dir)
                 if d.startswith(folder_name) and os.path.isdir(os.path.join(root_dir, d))),
                None,
            )
            if subdir_path is not None and os.path.isdir(subdir_path):
                label_path = os.path.join(subdir_path, 'hsi')
                scube_path = os.path.join(subdir_path, 'hsi')
                if os.path.exists(scube_path):
                    self.data_paths.append((label_path, scube_path))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        label_path, scube_path = self.data_paths[idx]

        # Load the sCube image (multi-band) and drop band index 10
        with rasterio.open(scube_path) as scube_src:
            scube_image = scube_src.read()  # Shape: [13, 512, 512]
            scube_image = scube_image[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12], :, :]  # Shape: [12, 512, 512]

        # Convert to a PyTorch tensor and resize to [12, 224, 224]
        scube_tensor = torch.from_numpy(scube_image).float()
        scube_tensor = F.interpolate(scube_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        scube_tensor = torch.nan_to_num(scube_tensor, nan=0.0)  # Replace NaNs with 0

        # Urban scenes carry no plume annotation, so the label is always 0
        contains_methane = torch.zeros(1).long()

        # Apply transformations (if any)
        if self.transform:
            transformed = self.transform(image=np.array(scube_tensor.permute(1, 2, 0)))
            scube_tensor = transformed['image'].transpose(2, 0, 1)  # Convert back to [C, H, W]

        return {'S2L2A': scube_tensor, 'label': contains_methane, 'sample': scube_path.split('/')[2]}
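The band selection in __getitem__ above keeps 12 of the 13 channels via fancy indexing. A minimal sketch on a dummy cube (the 4x4 spatial size is illustrative; real patches are 512x512):

```python
import numpy as np

# Dummy 13-band cube standing in for scube_src.read()
cube = np.arange(13 * 4 * 4, dtype=np.float32).reshape(13, 4, 4)

# Same index list as the dataset: band index 10 is dropped
keep = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12]
subset = cube[keep, :, :]
```

After the selection, output channels 0-9 match input bands 0-9, while output channels 10 and 11 come from input bands 11 and 12.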
urban_inference/urban_inference.py
ADDED
@@ -0,0 +1,221 @@
import argparse
import logging
import csv
import random
import warnings
import time
import os
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from rasterio.errors import NotGeoreferencedWarning

# --- Critical imports ---
import terramind
from terratorch.tasks import ClassificationTask

# Local imports
from methane_urban_datamodule import MethaneUrbanDataModule

# --- Configuration & Setup ---

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_inference_transforms() -> Optional[A.Compose]:
    # No augmentations at inference time
    return None


# --- Inference Class ---

class UrbanInference:
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initializing inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)

        # Lightning checkpoints nest weights under 'state_dict'; raw .pth files do not
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates binary predictions and matches them, in order, with the
        provided sample_names (folder names).
        """
        sample_results = {}

        logger.info(f"Starting inference on {len(sample_names)} samples...")

        # Iterator over sample names, matched to sequential predictions
        name_iter = iter(sample_names)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                # Forward pass
                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)

                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()

                # Assign directory names to predictions
                for pred in batch_preds:
                    try:
                        dir_name = next(name_iter)
                        sample_results[dir_name] = int(pred)
                    except StopIteration:
                        logger.error("More predictions generated than sample names provided! Check dataloader sync.")
                        break

        # Check whether any samples were missed
        if len(sample_results) != len(sample_names):
            logger.warning(f"Mismatch: expected {len(sample_names)} results, got {len(sample_results)}.")

        # Save CSV
        self._save_results(sample_results)

    def _save_results(self, results: Dict[str, int]):
        csv_path = self.output_dir / "inference_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_Directory', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")


# --- Data Loading ---

def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    root_path = Path(args.root_dir)
    if not root_path.exists():
        raise FileNotFoundError(f"Data directory {args.root_dir} not found.")

    paths = None
    if args.excel_file:
        try:
            df = pd.read_excel(args.excel_file)
            # Use the directory-name prefix of the 'Filename' column
            paths = df['Filename'].apply(lambda x: str(x).split('_')[0]).tolist()
            logger.info(f"Filtered {len(paths)} samples from Excel.")
        except Exception as e:
            logger.error(f"Error reading Excel: {e}")
            raise

    if paths is None:
        # Fall back to all subdirectories.
        # Sorting is crucial here to match the sequential dataloader.
        paths = sorted([d.name for d in root_path.iterdir() if d.is_dir()])
        logger.info(f"Found {len(paths)} samples in directory (sorted).")

    # Initialize the DataModule
    datamodule = MethaneUrbanDataModule(
        data_root=args.root_dir,
        excel_file=None,
        batch_size=args.batch_size,
        paths=paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
        test_transform=get_inference_transforms()
    )

    # Setup for the test stage
    datamodule.paths = paths
    datamodule.setup(stage="test")

    # Get the loader (prefer test_dataloader)
    loader = datamodule.test_dataloader() if hasattr(datamodule, 'test_dataloader') else datamodule.train_dataloader()

    return loader, paths


# --- Main Execution ---

def parse_args():
    parser = argparse.ArgumentParser(description="Methane Urban Inference (Directory Names)")

    parser.add_argument('--root_dir', type=str, required=True, help='Root directory containing sample folders')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')
    parser.add_argument('--excel_file', type=str, help='Optional Excel file to filter specific samples')
    parser.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    parser.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)

    # 1. Prepare data & capture directory names
    dataloader, sample_names = get_dataloader_and_names(args)

    # 2. Run inference
    engine = UrbanInference(args)
    engine.run_inference(dataloader, sample_names)
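The _save_results method in urban_inference.py writes one CSV row per sample. The same pattern in isolation, using an in-memory buffer instead of a file (the sample names and predictions are made up):

```python
import csv
import io

# Hypothetical sample-directory -> prediction map, as built by run_inference
results = {"urban_001": 0, "urban_002": 1}

# Same header and row layout as _save_results, written to a StringIO buffer
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(['Sample_Directory', 'Prediction'])
for sample, pred in results.items():
    writer.writerow([sample, pred])

rows = buf.getvalue().strip().splitlines()
```

Because Python dicts preserve insertion order, the CSV rows come out in the same order the predictions were assigned, which is what keeps them aligned with the sorted directory names.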