| | import argparse |
| | import logging |
| | import csv |
| | import random |
| | import warnings |
| | import time |
| | from pathlib import Path |
| | from typing import Dict, List, Tuple, Any, Optional |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import albumentations as A |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.metrics import ( |
| | accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix |
| | ) |
| | from rasterio.errors import NotGeoreferencedWarning |
| | import terramind |
| |
|
| | |
| | from methane_classification_datamodule import MethaneClassificationDataModule |
| |
|
| | |
| | from terratorch.tasks import ClassificationTask |
| |
|
| |
|
| | |
| |
|
| | |
# Root logging configuration: timestamped, INFO-level records for the whole script.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Quieten third-party noise: rasterio's environment logger only reports errors,
# rasters without georeferencing do not warn, and FutureWarnings are hidden
# process-wide.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
warnings.simplefilter("ignore", FutureWarning)
| |
|
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators with *seed*.

    When CUDA is available, the generators of every visible GPU are seeded as well.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
| |
|
def get_training_transforms() -> A.Compose:
    """Build the albumentations augmentation pipeline used for training samples.

    The pipeline applies, in order:
    - an elastic deformation (p=0.25);
    - a random multiple-of-90-degree rotation (p=0.5);
    - a random flip (p=0.5);
    - a shift/scale/rotate with up to 5% shift per axis and up to 90 degrees
      rotation (p=0.5).

    NOTE(review): recent albumentations releases deprecate ``Flip`` and
    ``ShiftScaleRotate`` (in favour of HorizontalFlip/VerticalFlip and Affine);
    confirm against the pinned albumentations version.
    """
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit_x=0.05,
            shift_limit_y=0.05,
            rotate_limit=90,
            p=0.5,
        ),
    ]
    return A.Compose(augmentations)
| |
|
| | |
| |
|
class MetricTracker:
    """Accumulates targets and predictions to calculate epoch-level metrics.

    ``update`` reduces both ``targets`` and ``probabilities`` with ``argmax`` over
    dim 1, so each is expected to be a ``(batch, num_classes)`` tensor (one-hot or
    probability rows). Metrics are binary, with class 1 as the positive class.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state."""
        self.all_targets = []  # target class index (int) per sample
        self.all_predictions = []  # predicted class index (int) per sample
        self.total_loss = 0.0  # sum over batches of (batch loss * batch size)
        self.steps = 0  # number of update() calls
        self.num_samples = 0  # number of samples seen across all batches

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch.

        Args:
            loss: Scalar loss of the batch. Assumed to be the batch mean (the
                default 'mean' reduction), which is why it is weighted by batch size.
            targets: ``(batch, num_classes)`` target rows; argmax gives the class.
            probabilities: ``(batch, num_classes)`` model scores; argmax gives the
                predicted class.
        """
        batch_size = int(targets.shape[0])
        # Weight each batch by its size so a short final batch does not skew the
        # epoch mean that drives checkpointing and LR scheduling.
        self.total_loss += loss * batch_size
        self.steps += 1
        self.num_samples += batch_size

        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().tolist())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().tolist())

    def compute(self) -> Dict[str, float]:
        """Return aggregate metrics for everything recorded since the last reset.

        Returns an empty dict if nothing has been recorded. Otherwise returns Loss
        (sample-weighted mean), Accuracy, Specificity, Sensitivity (recall), F1 and
        MCC. A ratio whose denominator is zero is reported as 0.0.

        Raises:
            ValueError: if any target or prediction is not 0 or 1.
        """
        if not self.all_targets:
            return {}

        y_true = np.asarray(self.all_targets, dtype=np.int64)
        y_pred = np.asarray(self.all_predictions, dtype=np.int64)
        if not (np.isin(y_true, (0, 1)).all() and np.isin(y_pred, (0, 1)).all()):
            raise ValueError("MetricTracker only supports binary (0/1) labels.")

        # One pass builds the 2x2 confusion matrix. Index 2*true + pred gives the
        # order tn, fp, fn, tp. Converting to float keeps the MCC product from
        # overflowing int64 on large epochs.
        tn, fp, fn, tp = (float(c) for c in np.bincount(2 * y_true + y_pred, minlength=4))

        total = tn + fp + fn + tp
        f1_denominator = 2.0 * tp + fp + fn
        mcc_denominator = float(np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))

        return {
            "Loss": self.total_loss / max(self.num_samples, 1),
            "Accuracy": (tp + tn) / total,
            "Specificity": tn / (tn + fp) if (tn + fp) != 0 else 0.0,
            "Sensitivity": tp / (tp + fn) if (tp + fn) != 0 else 0.0,
            "F1": 2.0 * tp / f1_denominator if f1_denominator != 0 else 0.0,
            "MCC": (tp * tn - fp * fn) / mcc_denominator if mcc_denominator != 0 else 0.0,
        }
| |
|
class MethaneTrainer:
    """
    Handles the training lifecycle: Model setup, Training loop, Validation, and Checkpointing.

    Checkpoints (``best_model.pth``, ``final_model.pth``) and the per-epoch metrics
    CSV are written to ``<args.save_dir>/fold<args.test_fold>``.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # _init_model also sets self.task; its loss function is reused below.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion

        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Build the TerraTorch ClassificationTask and return its model on ``self.device``.

        Side effect: stores the task on ``self.task`` so its ``criterion`` can be reused.
        """
        model_config = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_config,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        # The task constructor typically builds the model and loss already
        # (torchgeo's BaseTask.__init__ calls configure_models/configure_losses).
        # Only build them if they are missing, so the pretrained backbone is not
        # constructed and loaded a second time.
        if getattr(self.task, "model", None) is None:
            self.task.configure_models()
        if getattr(self.task, "criterion", None) is None:
            self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Create the AdamW optimizer and a ReduceLROnPlateau scheduler keyed on val loss.

        The task's own optimizer settings are not used; training runs on this optimizer.
        Learning-rate reductions are logged in ``fit`` rather than through the
        scheduler's ``verbose`` flag, which is deprecated in recent PyTorch releases.
        """
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=5
        )
        return optimizer, scheduler

    def _forward_batch(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the model on one batch and return ``(loss, targets, class_probabilities)``.

        Reads ``batch['S2L2A']`` (model input) and ``batch['label']`` (targets) and
        moves both to the trainer's device.
        """
        inputs = batch['S2L2A'].to(self.device)
        targets = batch['label'].to(self.device)

        logits = self.model(x={"S2L2A": inputs}).output
        # CrossEntropyLoss applies log-softmax internally, so it must receive raw
        # logits. Passing softmax output would apply softmax twice, distorting the
        # loss and shrinking gradients.
        loss = self.criterion(logits, targets)
        # Probabilities are only used for metrics, so they need no autograd graph.
        probabilities = torch.softmax(logits.detach(), dim=1)
        return loss, targets, probabilities

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run a single epoch over ``dataloader`` and return its aggregated metrics.

        When ``stage == "train"`` the model is in training mode and weights are
        updated. Any other stage runs in eval mode with gradients disabled. Returns
        an empty dict when the dataloader yields no batches.
        """
        is_train = stage == "train"
        # Module.train(False) is equivalent to Module.eval().
        self.model.train(is_train)

        tracker = MetricTracker()

        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)

            for batch in pbar:
                loss, targets, probabilities = self._forward_batch(batch)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                loss_value = loss.item()
                tracker.update(loss_value, targets, probabilities)
                pbar.set_postfix(loss=f"{loss_value:.4f}")

        return tracker.compute()

    def save_checkpoint(self, filename: str):
        """Save the model's ``state_dict`` to ``self.save_dir / filename``."""
        path = self.save_dir / filename
        torch.save(self.model.state_dict(), path)
        logger.info(f"Saved model to {path}")

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Appends metrics to the CSV log file.

        The header row is written only when the file does not exist yet. Rows from
        an earlier run in the same directory are therefore kept and appended to.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        file_exists = csv_path.exists()

        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]

        with open(csv_path, mode='a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(headers)

            row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())
            writer.writerow(row)

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Main training entry point.

        Each epoch trains, validates, steps the LR scheduler on validation loss,
        logs metrics to CSV, and saves ``best_model.pth`` whenever validation loss
        improves. ``final_model.pth`` is saved after the last epoch.

        Raises:
            RuntimeError: if either dataloader yields no batches.
        """
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")
            if not train_metrics or not val_metrics:
                raise RuntimeError("A dataloader produced no batches; cannot compute epoch metrics.")

            # Log LR reductions explicitly (replaces the deprecated scheduler verbose flag).
            lrs_before = [group['lr'] for group in self.optimizer.param_groups]
            self.scheduler.step(val_metrics['Loss'])
            lrs_after = [group['lr'] for group in self.optimizer.param_groups]
            if lrs_after != lrs_before:
                logger.info(f"Learning rate reduced: {lrs_before} -> {lrs_after}")

            self.log_to_csv(epoch, train_metrics, val_metrics)
            logger.info(
                f"Train Loss: {train_metrics['Loss']:.4f} | "
                f"Val Loss: {val_metrics['Loss']:.4f} | "
                f"Val F1: {val_metrics['F1']:.4f}"
            )

            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                self.save_checkpoint("best_model.pth")
                logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")

        self.save_checkpoint("final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
| |
|
| |
|
| | |
| |
|
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Prepares DataModule and returns Train/Val loaders for one cross-validation fold.

    The summary file is read as CSV when its suffix is ``.csv`` (any case) and
    with ``pd.read_excel`` otherwise. Rows whose ``Fold`` differs from
    ``args.test_fold`` form the training pool. Their ``Filename`` values are split
    80/20 into train/validation using ``random_state=args.seed``. The held-out test
    fold is never loaded here.

    Args:
        args: Parsed CLI namespace. Reads ``excel_file``, ``num_folds``,
            ``test_fold``, ``seed``, ``root_dir`` and ``batch_size``.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        KeyError: if the summary file lacks a ``Fold`` or ``Filename`` column.
        ValueError: if no rows belong to the training-pool folds.
    """
    # Compare the suffix case-insensitively so e.g. "SUMMARY.CSV" is not sent to read_excel.
    is_csv = Path(args.excel_file).suffix.lower() == '.csv'
    try:
        df = pd.read_csv(args.excel_file) if is_csv else pd.read_excel(args.excel_file)
    except Exception as e:
        logger.error(f"Failed to load summary file: {e}")
        raise

    missing_columns = {'Fold', 'Filename'} - set(df.columns)
    if missing_columns:
        raise KeyError(f"Summary file is missing required column(s): {sorted(missing_columns)}")

    all_folds = range(1, args.num_folds + 1)
    if args.test_fold not in all_folds:
        # No fold would be excluded, so the intended test data would leak into training.
        logger.warning(
            f"test_fold={args.test_fold} is outside 1..{args.num_folds}; no fold is held out."
        )
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    df_filtered = df[df['Fold'].isin(train_pool_folds)]
    if df_filtered.empty:
        raise ValueError(f"No data found for folds {train_pool_folds}. Check 'Fold' column in Excel.")

    paths = df_filtered['Filename'].tolist()

    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)

    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    datamodule = MethaneClassificationDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=train_paths,
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # The datamodule is reused for both splits: point it at each path list before
    # the matching setup() call. Presumably setup() reads self.paths; verify
    # against MethaneClassificationDataModule.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
| |
|
| |
|
| | |
| |
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for a single-fold training run.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            which preserves the original zero-argument behaviour.

    Returns:
        The parsed namespace.

    Exits (via ``parser.error``, SystemExit code 2) when ``--test_fold`` is not
    between 1 and ``--num_folds``. Such a fold would hold nothing out for testing.
    """
    parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")

    # Paths
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for satellite images')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel/CSV file')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory to save outputs')

    # Training hyperparameters and cross-validation setup
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Mini-batch size')
    parser.add_argument('--lr', type=float, default=1e-5, help='AdamW learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='AdamW weight decay')
    parser.add_argument('--num_folds', type=int, default=5, help='Number of folds in the summary file')
    parser.add_argument('--test_fold', type=int, default=2, help='Fold ID to hold out for testing')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for splitting and training')

    args = parser.parse_args(argv)
    if not 1 <= args.test_fold <= args.num_folds:
        parser.error(
            f"--test_fold must be between 1 and --num_folds ({args.num_folds}), got {args.test_fold}"
        )
    return args
| |
|
if __name__ == "__main__":
    # Parse options and seed every RNG before any data split or model is built.
    cli_args = parse_args()
    set_seed(cli_args.seed)

    # Build the loaders first, then construct the trainer and run the fit loop.
    loaders = get_data_loaders(cli_args)
    MethaneTrainer(cli_args).fit(*loaders)