"""Train a TerraMind-based binary methane classifier on S2L2A inputs with TerraTorch.

The script holds out one fold for testing, splits the remaining folds 80/20
into train/validation sets, trains with AdamW + ReduceLROnPlateau and writes
per-epoch metrics and checkpoints under ``<save_dir>/fold<test_fold>``.
"""

import argparse
import csv
import logging
import random
import time
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from rasterio.errors import NotGeoreferencedWarning
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    confusion_matrix,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

import terramind

# Local Imports
from methane_classification_datamodule import MethaneClassificationDataModule

# TerraTorch Imports
from terratorch.tasks import ClassificationTask

# --- Configuration & Setup ---

# Configure Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress Warnings
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def set_seed(seed: int = 42):
    """Seed ``random``, ``numpy`` and ``torch`` (CPU and all CUDA devices).

    Args:
        seed: Seed value applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_training_transforms() -> A.Compose:
    """Return the albumentations augmentation pipeline used for training.

    Returns:
        A composition of elastic, 90-degree rotation, flip and
        shift/scale/rotate augmentations, each applied with its own
        probability.
    """
    # NOTE(review): A.Flip is deprecated in recent albumentations releases;
    # confirm against the pinned version before upgrading.
    return A.Compose([
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
    ])


# --- Helper Classes ---

class MetricTracker:
    """Accumulates targets and predictions to calculate epoch-level metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated labels, predictions and the running loss."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0
        self.steps = 0

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch.

        Args:
            loss: Scalar loss value of the batch.
            targets: Per-class target scores; reduced to class indices with
                ``argmax`` over dim 1 (presumably one-hot labels of shape
                (batch, classes) -- confirm against the datamodule).
            probabilities: Per-class scores; reduced to predicted class
                indices with ``argmax`` over dim 1.
        """
        self.total_loss += loss
        self.steps += 1
        # Store detached CPU values as plain Python ints so no device memory
        # or autograd graph is kept alive across the epoch.
        self.all_targets.extend(torch.argmax(targets, dim=1).detach().cpu().tolist())
        self.all_predictions.extend(torch.argmax(probabilities, dim=1).detach().cpu().tolist())

    def compute(self) -> Dict[str, float]:
        """Calculate aggregate metrics for the accumulated data.

        Returns:
            A dict with keys ``Loss``, ``Accuracy``, ``Specificity``,
            ``Sensitivity``, ``F1`` and ``MCC``; an empty dict when no batch
            has been recorded.
        """
        if not self.all_targets:
            return {}

        # labels=[0, 1] forces a 2x2 matrix even if one class is absent,
        # so the 4-way unpacking below is always valid.
        tn, fp, fn, tp = confusion_matrix(
            self.all_targets, self.all_predictions, labels=[0, 1]
        ).ravel()
        negatives = tn + fp

        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(self.all_targets, self.all_predictions),
            "Specificity": float(tn / negatives) if negatives != 0 else 0.0,
            "Sensitivity": recall_score(
                self.all_targets, self.all_predictions,
                average='binary', pos_label=1, zero_division=0
            ),
            "F1": f1_score(
                self.all_targets, self.all_predictions,
                average='binary', pos_label=1, zero_division=0
            ),
            "MCC": matthews_corrcoef(self.all_targets, self.all_predictions),
        }


class MethaneTrainer:
    """
    Handles the training lifecycle: Model setup, Training loop, Validation, and Checkpointing.
    """

    def __init__(self, args: argparse.Namespace):
        """Build the model, optimizer and scheduler and create the output directory.

        Args:
            args: Parsed CLI namespace; reads ``save_dir``, ``test_fold``,
                ``lr``, ``weight_decay`` and (later, in ``fit``) ``epochs``.
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # _init_model also sets self.task, which the criterion lookup needs.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion  # Retrieved from the TerraTorch task
        self.best_val_loss = float('inf')

        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Initialize the TerraTorch classification task and return its model.

        Side effect: stores the task on ``self.task`` so the configured loss
        can be retrieved afterwards.

        Returns:
            The task's model, moved to ``self.device``.
        """
        model_config = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        self.task = ClassificationTask(
            model_args=model_config,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Create the AdamW optimizer and the plateau-based LR scheduler.

        Returns:
            Tuple ``(optimizer, scheduler)``.
        """
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )
        # `verbose` is deprecated in torch LR schedulers and rejected by newer
        # releases; learning-rate changes are logged explicitly in fit().
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=5
        )
        return optimizer, scheduler

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run a single epoch for either training or validation.

        Args:
            dataloader: Yields dict batches with ``'S2L2A'`` inputs and
                ``'label'`` targets.
            stage: ``"train"`` enables gradients and optimizer steps; any
                other value runs the model in eval mode without gradients.

        Returns:
            Epoch metrics as produced by ``MetricTracker.compute``.
        """
        is_train = stage == "train"
        self.model.train(is_train)

        tracker = MetricTracker()

        # Context manager: enable grad only if training
        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                inputs = batch['S2L2A'].to(self.device)
                targets = batch['label'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                logits = outputs.output

                # The task is configured with loss="ce"; cross-entropy applies
                # log-softmax internally and therefore must receive raw
                # logits. Feeding softmax output would apply softmax twice.
                loss = self.criterion(logits, targets)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                loss_value = loss.item()

                # Probabilities are only needed for metric bookkeeping.
                probabilities = torch.softmax(logits.detach(), dim=1)
                tracker.update(loss_value, targets, probabilities)

                # Update progress bar description with live loss
                pbar.set_postfix(loss=f"{loss_value:.4f}")

        return tracker.compute()

    def save_checkpoint(self, filename: str):
        """Save the model ``state_dict`` to ``self.save_dir / filename``."""
        path = self.save_dir / filename
        torch.save(self.model.state_dict(), path)
        logger.info(f"Saved model to {path}")

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Append one row of metrics to ``train_val_metrics.csv``.

        The header row (``Epoch``, ``Train_<key>``..., ``Val_<key>``...) is
        written only when the file does not exist yet.

        Args:
            epoch: 1-based epoch number.
            train_metrics: Metrics of the training pass.
            val_metrics: Metrics of the validation pass.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        file_exists = csv_path.exists()

        # Define headers based on metric keys
        headers = (
            ['Epoch']
            + [f'Train_{k}' for k in train_metrics]
            + [f'Val_{k}' for k in val_metrics]
        )
        row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(headers)
            writer.writerow(row)

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Main training entry point.

        Runs ``args.epochs`` epochs of train + validation, steps the LR
        scheduler on the validation loss, logs metrics to CSV, saves
        ``best_model.pth`` whenever the validation loss improves and
        ``final_model.pth`` at the end.

        Args:
            train_loader: Loader for the training split.
            val_loader: Loader for the validation split.
        """
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.perf_counter()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            # Run Training & Validation
            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # Scheduler Step; report reductions ourselves since the
            # scheduler no longer prints them.
            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_metrics['Loss'])
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                logger.info(f"Learning rate changed: {lr_before:.3e} -> {lr_after:.3e}")

            # Logging
            self.log_to_csv(epoch, train_metrics, val_metrics)
            logger.info(
                f"Train Loss: {train_metrics['Loss']:.4f} | "
                f"Val Loss: {val_metrics['Loss']:.4f} | "
                f"Val F1: {val_metrics['F1']:.4f}"
            )

            # Save Best Model
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                self.save_checkpoint("best_model.pth")
                logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")

        # End of training
        self.save_checkpoint("final_model.pth")
        logger.info(f"Training finished in {time.perf_counter() - start_time:.2f}s")


# --- Data Utilities ---

def get_data_loaders(args: argparse.Namespace) -> Tuple[DataLoader, DataLoader]:
    """Prepare the DataModule and return the train/validation loaders.

    Args:
        args: Parsed CLI namespace; reads ``excel_file``, ``num_folds``,
            ``test_fold``, ``seed``, ``root_dir`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Raises:
        ValueError: If no rows remain after excluding the test fold.
        Exception: Any error raised while reading the summary file is logged
            and re-raised unchanged.
    """
    # Read the summary table (CSV or Excel)
    try:
        if str(args.excel_file).lower().endswith('.csv'):
            df = pd.read_csv(args.excel_file)
        else:
            df = pd.read_excel(args.excel_file)
    except Exception as e:
        logger.error(f"Failed to load summary file: {e}")
        raise

    # Determine training pool (all folds except test_fold)
    all_folds = range(1, args.num_folds + 1)
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # Filter filenames
    df_filtered = df[df['Fold'].isin(train_pool_folds)]
    if df_filtered.empty:
        raise ValueError(f"No data found for folds {train_pool_folds}. Check 'Fold' column in Excel.")

    paths = df_filtered['Filename'].tolist()

    # 80/20 Split
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # Initialize DataModule
    datamodule = MethaneClassificationDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=train_paths,
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # Create Loaders
    # NOTE(review): the same datamodule is re-pointed at the validation paths
    # after the train loader is built; this presumably relies on setup()
    # snapshotting `paths` into a dataset -- verify against the datamodule
    # that the train loader does not observe the later mutation.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader


# --- Main Execution ---

def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the training script."""
    parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")

    # Paths
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for satellite images')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to summary Excel/CSV file')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory to save outputs')

    # Training Config
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=2, help='Fold ID to hold out for testing')
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()


def main():
    """Script entry point: seed, build loaders, train."""
    args = parse_args()
    set_seed(args.seed)

    # Prepare Data
    train_loader, val_loader = get_data_loaders(args)

    # Initialize Trainer and Start
    trainer = MethaneTrainer(args)
    trainer.fit(train_loader, val_loader)


if __name__ == "__main__":
    main()