| | import argparse |
| | import logging |
| | import csv |
| | import random |
| | import warnings |
| | import time |
| | import json |
| | from pathlib import Path |
| | from functools import partial |
| | from typing import Dict, List, Tuple, Any, Optional |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import albumentations as A |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.metrics import ( |
| | accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix |
| | ) |
| | from rasterio.errors import NotGeoreferencedWarning |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | |
| | import terramind |
| | from terratorch.tasks import ClassificationTask |
| | from terratorch.registry import TERRATORCH_BACKBONE_REGISTRY, TERRATORCH_DECODER_REGISTRY |
| | from terramind.models.terramind_register import build_terrammind_vit |
| |
|
| | |
| | from methane_text_datamodule import MethaneTextDataModule |
| |
|
| | |
| |
|
# Root logger: timestamped, INFO-level records for the whole process.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Quieten known-noisy sources: rasterio's environment logger only reports
# errors, and georeferencing / FutureWarning warnings are ignored.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
| |
|
| | |
# Band names, in order, listed per pretrained input key; passed to the
# backbone builder as `pretrained_bands`.
PRETRAINED_BANDS = {
    'untok_sen2l2a@224': [
        "COASTAL_AEROSOL",
        "BLUE",
        "GREEN",
        "RED",
        "RED_EDGE_1",
        "RED_EDGE_2",
        "RED_EDGE_3",
        "NIR_BROAD",
        "NIR_NARROW",
        "WATER_VAPOR",
        "SWIR_1",
        "SWIR_2",
    ],
}
| |
|
def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus all GPUs if present).

    Args:
        seed: Value handed to every random number generator.
    """
    # Same order as before: stdlib, NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
| |
|
def get_training_transforms() -> A.Compose:
    """Build the albumentations pipeline applied to training samples.

    Returns:
        An `A.Compose` chaining elastic deformation, 90-degree rotation,
        flipping and a small shift/scale/rotate, each with its own probability.
    """
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(
            rotate_limit=90,
            shift_limit_x=0.05,
            shift_limit_y=0.05,
            p=0.5,
        ),
    ]
    return A.Compose(augmentations)
| |
|
| | |
| |
|
| | |
def _load_text_encoder():
    """Instantiate the MiniLM sentence encoder, moved to CUDA when available."""
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    return encoder.to("cuda") if torch.cuda.is_available() else encoder


# Best-effort load at import time: on any failure the module still imports
# and EMBB_MODEL is left as None.
try:
    EMBB_MODEL = _load_text_encoder()
except Exception as e:
    logger.warning(f"Could not load SentenceTransformer: {e}")
    EMBB_MODEL = None
| |
|
class TerraMindWithText(nn.Module):
    """TerraMind vision backbone that also returns a caption embedding.

    The vision encoder is built with `build_terrammind_vit`; captions are
    embedded with the module-level `EMBB_MODEL` sentence encoder (no gradients
    flow into it). The caption embedding is appended as the final element of
    the feature list so a decoder can pick it up via `features[-1]`.
    """

    def __init__(self, terramind_kwargs: dict):
        """Build the TerraMind encoder.

        Args:
            terramind_kwargs: Extra keyword arguments forwarded verbatim to
                `build_terrammind_vit` on top of the fixed base configuration.
        """
        super().__init__()
        self.terramind = build_terrammind_vit(
            variant='terramind_v1_base',
            encoder_depth=12,
            dim=768,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=False,
            proj_bias=False,
            mlp_bias=False,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.SiLU,
            gated_mlp=True,
            pretrained_bands=PRETRAINED_BANDS,
            **terramind_kwargs
        )
        # One 768-wide feature entry per encoder layer (depth 12).
        self.out_channels = [768] * 12

    def forward(self, x, captions):
        """Encode the image input and the captions.

        Args:
            x: Input passed straight to the TerraMind encoder.
            captions: Caption text(s) passed to `EMBB_MODEL.encode`.

        Returns:
            A list holding the encoder's features followed by the caption
            embedding tensor as the last element.

        Raises:
            RuntimeError: If the sentence encoder failed to load at import
                time (`EMBB_MODEL is None`).
        """
        # Fail early with a clear message instead of an AttributeError on None.
        if EMBB_MODEL is None:
            raise RuntimeError(
                "SentenceTransformer caption encoder is not available; "
                "cannot embed captions."
            )

        vision_features = self.terramind(x)

        # The text encoder is frozen: never track gradients through it.
        with torch.no_grad():
            captions_embed = EMBB_MODEL.encode(
                captions, convert_to_tensor=True, show_progress_bar=False
            )

        # Drop only a singleton middle axis. A bare squeeze() would also
        # remove the batch axis when the batch holds a single sample.
        if captions_embed.dim() == 3:
            captions_embed = captions_embed.squeeze(1)

        # Assumes the encoder returns a sequence of tensors (the original list
        # concatenation relied on this too); keep the embedding on their device.
        vision_features = list(vision_features)
        if vision_features and torch.is_tensor(vision_features[0]):
            captions_embed = captions_embed.to(vision_features[0].device)

        return vision_features + [captions_embed]
| |
|
@TERRATORCH_BACKBONE_REGISTRY.register
def terramind_v1_base_with_text(**kwargs):
    """Registry factory: build a `TerraMindWithText` from keyword arguments.

    All keyword arguments are handed over unchanged as `terramind_kwargs`.
    """
    backbone = TerraMindWithText(terramind_kwargs=kwargs)
    return backbone
| |
|
@TERRATORCH_DECODER_REGISTRY.register
class SimpleDecoder(nn.Module):
    """Fuse image tokens with a caption embedding and emit class logits.

    Expects the feature list produced by the text-aware backbone: per-layer
    image token tensors followed by the caption embedding as the last element.
    Image tokens are averaged over layers, reshaped to a square grid and
    convolved; the caption embedding goes through an MLP, is broadcast over
    the grid and concatenated before the fusion and head convolutions.

    The head emits `num_classes` channels and is global-average-pooled, so the
    output is `(batch, num_classes)` logits. Previously `num_classes` was
    ignored and a single-channel spatial map was returned, which makes a
    softmax over dim 1 constant.
    """

    # The decoder produces final logits itself; no extra head is needed.
    includes_head = True

    def __init__(self, input_dim=768, num_classes=2, caption_dim=384):
        """Create the decoder layers.

        Args:
            input_dim: Channel width of the image tokens; if a list/tuple is
                given, its first entry is used.
            num_classes: Number of output logits per sample.
            caption_dim: Width of the caption embedding.
        """
        super().__init__()

        dim = input_dim[0] if isinstance(input_dim, (list, tuple)) else input_dim

        self.image_conv = nn.Sequential(
            nn.Conv2d(dim, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        self.caption_mlp = nn.Sequential(
            nn.Linear(caption_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

        # NOTE(review): defined but not used in forward(); kept so the
        # module's parameter set stays as it was.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=256, num_heads=8, dropout=0.1, batch_first=True
        )

        # Input is 512 channels: 256 image + 256 broadcast caption channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3)
        )

        self.conv_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
            # One channel per class so the pooled output is a logit vector.
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

        self.out_channels = num_classes

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Compute class logits from backbone features.

        Args:
            features: Image token tensors of shape `(B, N, C)` followed by the
                caption embedding as the last element.

        Returns:
            Logits of shape `(B, num_classes)`.

        Raises:
            ValueError: If the token count `N` is not a perfect square and so
                cannot be laid out as an `H x W` grid.
        """
        caption_embed = features[-1]
        # Everything before the caption embedding is an image feature map.
        image_features = list(features[:-1])

        # Average the per-layer token tensors into one (B, N, C) tensor.
        x = torch.stack(image_features, dim=1).mean(dim=1)

        B, N, C = x.shape
        H = W = int(round(N ** 0.5))
        if H * W != N:
            raise ValueError(
                f"Cannot reshape {N} tokens into a square grid; "
                "expected a perfect-square token count."
            )

        # (B, N, C) -> (B, C, H, W); reshape copes with the non-contiguous permute.
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        img_features = self.image_conv(x)

        # A single un-batched embedding is promoted to a batch of one.
        if caption_embed.dim() == 1:
            caption_embed = caption_embed.unsqueeze(0)
        caption_embed = caption_embed.to(x.device)

        caption_features = self.caption_mlp(caption_embed)

        # Broadcast the caption vector to every spatial location.
        caption_spatial = caption_features[:, :, None, None].expand(B, -1, H, W)

        fused_features = torch.cat([img_features, caption_spatial], dim=1)
        fused = self.fusion_conv(fused_features)

        logits_map = self.conv_head(fused)
        # Global average pooling over the grid gives per-sample class logits.
        return logits_map.mean(dim=(2, 3))
| |
|
| | |
| |
|
class MetricTracker:
    """Accumulate loss, targets and predictions over an epoch.

    `update` is called once per batch; `compute` turns the accumulated values
    into summary classification metrics for a binary (labels 0/1) problem.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    @staticmethod
    def _to_class_indices(values: torch.Tensor) -> list:
        """Return class indices as plain ints.

        Tensors with more than one dimension (one-hot targets or per-class
        scores) are reduced with argmax over dim 1; 1-D tensors are taken to
        already hold class indices.
        """
        if values.dim() > 1:
            values = torch.argmax(values, dim=1)
        return values.detach().to(torch.long).cpu().reshape(-1).tolist()

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch.

        Args:
            loss: Scalar loss of the batch.
            targets: One-hot/probability targets `(B, C)` or class indices `(B,)`.
            probabilities: Per-class scores `(B, C)`.
        """
        self.total_loss += loss
        self.steps += 1
        self.all_targets.extend(self._to_class_indices(targets))
        self.all_predictions.extend(self._to_class_indices(probabilities))

    def compute(self) -> Dict[str, float]:
        """Return the epoch metrics, or an empty dict if nothing was recorded."""
        if not self.all_targets:
            return {}

        # labels=[0, 1] forces a 2x2 matrix even if one class is absent.
        tn, fp, fn, tp = confusion_matrix(self.all_targets, self.all_predictions, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": recall_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "F1": f1_score(self.all_targets, self.all_predictions, average='binary', pos_label=1, zero_division=0),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
        }
| |
|
class MethaneTextTrainer:
    """Training/validation loop for the text-aware methane classifier.

    Builds the model through a TerraTorch `ClassificationTask`, trains with
    AdamW and a `ReduceLROnPlateau` scheduler driven by validation loss, logs
    per-epoch metrics to CSV and saves the best and final weights under
    `<save_dir>/fold<test_fold>/`.
    """

    def __init__(self, args: argparse.Namespace):
        """Set up output directory, model, optimizer, scheduler and loss.

        Args:
            args: Parsed CLI arguments; reads `save_dir`, `test_fold`, `lr`,
                `weight_decay` (and `epochs` later in `fit`).
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # _init_model also sets self.task, which the criterion is read from.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Create the `ClassificationTask`, build its model/loss, return the model on device."""
        model_args = dict(
            backbone="terramind_v1_base_with_text",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            num_classes=2,
            head_dropout=0.3,
            decoder="SimpleDecoder",
        )

        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Return the AdamW optimizer and the plateau scheduler (patience 5)."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        # `verbose` is deprecated in recent PyTorch; learning-rate changes are
        # logged explicitly in fit() instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one pass over `dataloader`.

        Args:
            dataloader: Yields dict batches with keys 'S2L2A', 'caption', 'label'.
            stage: "train" enables gradients and optimizer steps; any other
                value runs in evaluation mode.

        Returns:
            The metrics dict from `MetricTracker.compute()` (empty if the
            dataloader produced no batches).
        """
        is_train = stage == "train"
        self.model.train(is_train)
        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                captions = batch['caption']
                targets = batch['label'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs}, captions=captions)
                logits = outputs.output

                # The loss was configured as cross-entropy ("ce"), which
                # applies log-softmax itself, so it must receive raw logits.
                # Feeding it softmax output flattens the gradients.
                loss = self.criterion(logits, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                loss_value = loss.item()
                # Probabilities are used for metrics only.
                probabilities = torch.softmax(logits.detach(), dim=1)
                tracker.update(loss_value, targets, probabilities)
                pbar.set_postfix(loss=f"{loss_value:.4f}")

        return tracker.compute()

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Append one epoch's metrics to `train_val_metrics.csv`.

        The header row is written when the file is new or empty. The check is
        made before opening, because opening in append mode creates the file.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics] + [f'Val_{k}' for k in val_metrics]

        needs_header = not csv_path.exists() or csv_path.stat().st_size == 0

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(headers)
            writer.writerow([epoch] + list(train_metrics.values()) + list(val_metrics.values()))

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Train for `args.epochs` epochs, checkpointing on best validation loss.

        Raises:
            ValueError: If either loader yields no batches, so no metrics
                exist for the epoch.
        """
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            if not train_metrics or not val_metrics:
                raise ValueError(
                    "Training or validation dataloader produced no batches; "
                    "cannot compute epoch metrics."
                )

            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_metrics['Loss'])
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                logger.info(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

            self.log_to_csv(epoch, train_metrics, val_metrics)

            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info("--> New best model saved")

        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
| |
|
| | |
| |
|
def read_captions(json_path: Path, captions_dict: Dict) -> Dict:
    """Read captions from a JSON file into `captions_dict`.

    The JSON is expected to map a file path string to a list whose first
    element is itself a list whose first element is the text. The caption is
    whatever follows the first "CAPTION:" marker, stripped. It is stored under
    the name of the file's parent directory (the second-to-last path
    component; both "/" and "\\" separators are accepted).

    Entries that do not have this shape, lack the marker, or have no parent
    directory are skipped, so one malformed entry no longer discards every
    entry after it.

    Args:
        json_path: Location of the captions JSON file.
        captions_dict: Dictionary to populate (mutated in place).

    Returns:
        The same `captions_dict`, for chaining. It is returned unchanged if
        the file is missing, unreadable, not valid JSON, or not a JSON object.
    """
    if not json_path.exists():
        logger.warning(f"Caption file not found: {json_path}")
        return captions_dict

    # Only reading/parsing is guarded; JSON and Unicode decoding errors are
    # both ValueError subclasses.
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except (OSError, ValueError) as e:
        logger.error(f"Error reading captions {json_path}: {e}")
        return captions_dict

    if not isinstance(data, dict):
        logger.error(
            f"Error reading captions {json_path}: expected a JSON object, got {type(data).__name__}"
        )
        return captions_dict

    marker = "CAPTION:"
    for file_path_str, text_list in data.items():
        if not (isinstance(text_list, list) and text_list):
            continue
        first_entry = text_list[0]
        if not (isinstance(first_entry, list) and first_entry):
            continue
        text_content = first_entry[0]
        if not isinstance(text_content, str):
            continue

        # Everything after the first marker occurrence is the caption.
        _, found, tail = text_content.partition(marker)
        if not found:
            continue

        path_parts = str(file_path_str).replace("\\", "/").split("/")
        if len(path_parts) >= 2:
            captions_dict[path_parts[-2]] = tail.strip()

    return captions_dict
| |
|
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return the 'Filename' values of all rows whose 'Fold' is in `folds`.

    Args:
        excel_file: Path to the summary spreadsheet read with `pd.read_excel`.
        folds: Fold numbers to keep.

    Returns:
        Filenames of the matching rows, in spreadsheet order.
    """
    summary = pd.read_excel(excel_file)
    in_requested_folds = summary['Fold'].isin(folds)
    return summary.loc[in_requested_folds, 'Filename'].tolist()
| |
|
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation dataloaders for one cross-validation run.

    Captions from both JSON files are merged into one dictionary. Every fold
    except `args.test_fold` forms the pool, which is split 80/20 into train
    and validation paths using `args.seed`.

    Args:
        args: Parsed CLI arguments; reads `methane_captions`,
            `no_methane_captions`, `num_folds`, `test_fold`, `excel_file`,
            `seed`, `root_dir` and `batch_size`.

    Returns:
        `(train_loader, val_loader)`.
    """
    # Merge both caption sources; later files overwrite duplicate keys.
    captions_dict = {}
    for caption_file in (args.methane_captions, args.no_methane_captions):
        captions_dict = read_captions(Path(caption_file), captions_dict)
    logger.info(f"Loaded {len(captions_dict)} captions.")

    # Every fold but the held-out test fold feeds the train/val pool.
    train_pool_folds = [
        fold for fold in range(1, args.num_folds + 1) if fold != args.test_fold
    ]
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)

    train_paths, val_paths = train_test_split(
        paths, test_size=0.2, random_state=args.seed
    )
    logger.info(f"Train: {len(train_paths)}, Val: {len(val_paths)}")

    datamodule = MethaneTextDataModule(
        data_root=args.root_dir,
        paths=paths,
        captions=captions_dict,
        train_transform=get_training_transforms(),
        batch_size=args.batch_size,
    )

    # The same datamodule is re-pointed at each split before its setup call;
    # the order of these statements matters.
    datamodule.paths = train_paths
    datamodule.setup(stage="train")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
| |
|
| | |
| |
|
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        An `argparse.Namespace` with the four required paths, `save_dir`, and
        the hyperparameters (`epochs`, `batch_size`, `lr`, `weight_decay`,
        `num_folds`, `test_fold`, `seed`).
    """
    parser = argparse.ArgumentParser(description="Methane Text-Multimodal Training")

    # Required input locations.
    required_paths = (
        ('--root_dir', 'Root directory for images'),
        ('--excel_file', 'Path to Summary Excel'),
        ('--methane_captions', 'Path to Methane JSON captions'),
        ('--no_methane_captions', 'Path to No-Methane JSON captions'),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, required=True, help=description)

    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Output directory')

    # Optional hyperparameters: (flag, type, default).
    hyperparameters = (
        ('--epochs', int, 100),
        ('--batch_size', int, 4),
        ('--lr', float, 5e-5),
        ('--weight_decay', float, 0.05),
        ('--num_folds', int, 5),
        ('--test_fold', int, 2),
        ('--seed', int, 42),
    )
    for flag, caster, fallback in hyperparameters:
        parser.add_argument(flag, type=caster, default=fallback)

    return parser.parse_args()
| |
|
if __name__ == "__main__":
    # Script entry point: parse options, seed the RNGs, build the loaders,
    # then construct the trainer and run the training loop.
    args = parse_args()
    set_seed(args.seed)

    # Loaders are built before the trainer, as in the original ordering.
    train_loader, val_loader = get_data_loaders(args)

    trainer = MethaneTextTrainer(args)
    trainer.fit(train_loader=train_loader, val_loader=val_loader)