# TerraMind-Methane-Classification/intuition1_classification_finetuning/script/train_simulated_I1.py
| import argparse | |
| import logging | |
| import csv | |
| import random | |
| import warnings | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Any, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import albumentations as A | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import ( | |
| accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix | |
| ) | |
| from rasterio.errors import NotGeoreferencedWarning | |
| # --- CRITICAL IMPORTS --- | |
| import terramind | |
| from terratorch.tasks import ClassificationTask | |
| # Local Imports | |
| from methane_simulated_datamodule import MethaneSimulatedDataModule | |
| # --- Configuration & Setup --- | |
# Root logger: timestamped "time - logger - level - message" lines at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Quieten third-party noise: rasterio's environment logger only reports errors,
# rasterio's NotGeoreferencedWarning is ignored, and FutureWarning is suppressed
# for the whole process.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.simplefilter("ignore", category=FutureWarning)
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device when available).

    Args:
        seed: Value passed to each generator's seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def get_training_transforms() -> A.Compose:
    """Build the random augmentation pipeline applied to training samples only.

    Applied in order:
    elastic deformation (p=0.25), random 90-degree rotation (p=0.5),
    random flip (p=0.5), and a shift/scale/rotate jitter with per-axis shift
    limit 0.05 and rotation limit 90 (p=0.5).
    """
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit_x=0.05,
            shift_limit_y=0.05,
            rotate_limit=90,
            p=0.5,
        ),
    ]
    return A.Compose(augmentations)
| # --- Path Utilities --- | |
def get_simulated_paths(paths: List[str]) -> List[str]:
    """Map original filenames onto the I1/TOA-reflectance naming convention.

    Each name is split on ``_``.
    - If it has at least five tokens, the result is
      ``{tok0}_toarefl_{tok3}_{tok4}``. Tokens 1 and 2, and anything after
      token 4, are dropped; e.g. ``'ang2015..._S2_...'`` becomes
      ``'ang2015..._toarefl_...'``.
    - Names with fewer tokens are returned unchanged.
    - Entries that cannot be split (e.g. non-string values) are logged as a
      warning and also returned unchanged.

    Args:
        paths: Original filenames.

    Returns:
        Converted filenames, in the same order and with the same length.
    """
    def _convert(path):
        try:
            tokens = path.split('_')
            if len(tokens) < 5:
                return path
            # Logic: {ID}_toarefl_{Coord1}_{Coord2}
            return f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}"
        except Exception as e:
            logger.warning(f"Could not parse path {path}: {e}")
            return path

    return [_convert(path) for path in paths]
def get_paths_for_fold(excel_file: str, folds: List[int]) -> List[str]:
    """Return the ``Filename`` entries of the summary sheet whose ``Fold`` is in *folds*.

    Args:
        excel_file: Path to the summary Excel workbook.
        folds: Fold IDs to keep.

    Returns:
        Matching filenames, in sheet order.

    Raises:
        Any error from reading or filtering the sheet (e.g. a missing file or
        column). It is logged first and then re-raised unchanged.
    """
    try:
        summary = pd.read_excel(excel_file)
        selected = summary['Fold'].isin(folds)
        return summary.loc[selected, 'Filename'].tolist()
    except Exception as e:
        logger.error(f"Error reading Excel file: {e}")
        raise
| # --- Helper Classes --- | |
class MetricTracker:
    """Accumulate per-batch loss and hard predictions, then summarise an epoch.

    Both targets and model outputs are reduced to class indices with
    ``argmax`` over dim 1. Targets are therefore expected one-hot (or
    per-class scores); TODO confirm against the datamodule. Binary metrics
    treat class 1 as positive.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.total_loss = 0.0
        self.steps = 0
        self.all_targets = []
        self.all_predictions = []

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch: its scalar loss plus argmax-decoded true and predicted labels."""
        def _to_labels(scores: torch.Tensor):
            return torch.argmax(scores, dim=1).detach().cpu().numpy()

        self.total_loss += loss
        self.steps += 1
        self.all_targets.extend(_to_labels(targets))
        self.all_predictions.extend(_to_labels(probabilities))

    def compute(self) -> Dict[str, float]:
        """Return epoch metrics.

        The result has mean loss, accuracy, specificity, sensitivity, F1, MCC
        and confusion counts. It is ``{}`` if nothing was recorded.
        """
        y_true, y_pred = self.all_targets, self.all_predictions
        if len(y_true) == 0:
            return {}
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        negatives = tn + fp
        # Guard against a batch set with no negative-class samples.
        specificity = tn / negatives if negatives != 0 else 0.0
        binary = dict(average='binary', pos_label=1, zero_division=0)
        return {
            "Loss": self.total_loss / max(self.steps, 1),
            "Accuracy": accuracy_score(y_true, y_pred),
            "Specificity": specificity,
            "Sensitivity": recall_score(y_true, y_pred, **binary),
            "F1": f1_score(y_true, y_pred, **binary),
            "MCC": matthews_corrcoef(y_true, y_pred),
            "TP": int(tp), "TN": int(tn), "FP": int(fp), "FN": int(fn),
        }
class TrainerI1:
    """Plain-PyTorch training loop for the TerraMind classifier on I1/TOA data.

    terratorch's ``ClassificationTask`` is used only to build the model and its
    loss. Optimisation, LR scheduling, checkpointing and CSV logging happen
    here. All outputs go to ``<save_dir>/fold<test_fold>/``.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # _init_model() also stores the ClassificationTask on self.task.
        self.model = self._init_model()
        self.optimizer, self.scheduler = self._init_optimizer()
        # Loss built by ClassificationTask.configure_losses() (loss="ce", ignore_index=-1).
        self.criterion = self.task.criterion
        self.best_val_loss = float('inf')
        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Build the TerraMind-base + UperNet classifier and move it to ``self.device``.

        Side effect: stores the configured ``ClassificationTask`` on ``self.task``.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )
        self.task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Create AdamW and a ReduceLROnPlateau scheduler driven by validation loss."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        # `verbose` is deprecated in PyTorch's LR schedulers; LR reductions are logged in fit() instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
        return optimizer, scheduler

    def _compute_batch(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward one batch and compute its loss.

        Returns:
            ``(loss, targets, probabilities)``. ``targets`` is the batch label
            moved to ``self.device``; ``probabilities`` is ``softmax(logits, dim=1)``.
        """
        inputs = batch['S2L2A'].to(self.device)
        targets = batch['label'].to(self.device)
        logits = self.model(x={"S2L2A": inputs}).output
        # CrossEntropyLoss applies log-softmax internally, so it must receive raw
        # logits; feeding softmax output here would apply softmax twice.
        loss = self.criterion(logits, targets)
        probabilities = torch.softmax(logits, dim=1)
        return loss, targets, probabilities

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run one pass over *dataloader*.

        Trains when ``stage == "train"``; otherwise evaluates without gradients.

        Returns:
            The ``MetricTracker.compute()`` summary (``{}`` if the loader is empty).
        """
        is_train = stage == "train"
        # train(False) is equivalent to eval().
        self.model.train(is_train)
        tracker = MetricTracker()
        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                loss, targets, probabilities = self._compute_batch(batch)
                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                loss_value = loss.item()
                tracker.update(loss_value, targets, probabilities)
                pbar.set_postfix(loss=f"{loss_value:.4f}")
        return tracker.compute()

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Train for ``args.epochs`` epochs, validating after each one.

        Each epoch:
        - steps the plateau scheduler on validation loss,
        - appends a row to ``train_val_metrics.csv``,
        - saves ``best_model.pth`` whenever validation loss improves.

        ``final_model.pth`` is written after the last epoch.

        Raises:
            RuntimeError: If the train or validation loader yields no batches.
        """
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()
        # Initialize CSV logging (truncates any previous run in this fold dir)
        csv_path = self.save_dir / 'train_val_metrics.csv'
        with open(csv_path, 'w', newline='') as f:
            csv.writer(f).writerow([
                'Epoch', 'Train_Loss', 'Train_F1', 'Train_Acc',
                'Val_Loss', 'Val_F1', 'Val_Acc', 'Val_Spec', 'Val_Sens'
            ])
        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")
            train_metrics = self.run_epoch(train_loader, stage="train")
            # run_epoch returns {} for an empty loader; fail clearly instead of a KeyError below.
            if not train_metrics:
                raise RuntimeError("Training dataloader produced no batches")
            val_metrics = self.run_epoch(val_loader, stage="validate")
            if not val_metrics:
                raise RuntimeError("Validation dataloader produced no batches")

            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_metrics['Loss'])
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                logger.info(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Log to CSV
            with open(csv_path, 'a', newline='') as f:
                csv.writer(f).writerow([
                    epoch,
                    train_metrics.get('Loss'), train_metrics.get('F1'), train_metrics.get('Accuracy'),
                    val_metrics.get('Loss'), val_metrics.get('F1'), val_metrics.get('Accuracy'),
                    val_metrics.get('Specificity'), val_metrics.get('Sensitivity')
                ])
            logger.info(f"Train Loss: {train_metrics['Loss']:.4f} | Val Loss: {val_metrics['Loss']:.4f} | Val F1: {val_metrics['F1']:.4f}")

            # Save Best Model (selected on validation loss)
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                torch.save(self.model.state_dict(), self.save_dir / "best_model.pth")
                logger.info("--> New best model saved")
        # Save Final Model
        torch.save(self.model.state_dict(), self.save_dir / "final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
| # --- Data Utilities --- | |
def get_data_loaders(args: argparse.Namespace) -> Tuple[DataLoader, DataLoader]:
    """Build the train and validation DataLoaders for one cross-validation run.

    Every fold except ``args.test_fold`` forms the training pool. Its filenames
    are read from the summary Excel and rewritten to the I1/TOA naming
    convention. They are then split 80/20 (seeded with ``args.seed``) into
    train and validation sets. The test fold itself is not loaded here.

    Args:
        args: Parsed CLI namespace. Reads ``num_folds``, ``test_fold``,
            ``excel_file``, ``seed``, ``root_dir`` and ``batch_size``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # 1. Determine Folds
    # Folds are numbered 1..num_folds (presumably matching the Excel 'Fold' column).
    all_folds = list(range(1, args.num_folds + 1))
    train_pool_folds = [f for f in all_folds if f != args.test_fold]
    # 2. Get Paths & Convert to TOA/I1 format
    # Note: Using get_simulated_paths to transform names as done in the notebook
    paths = get_paths_for_fold(args.excel_file, train_pool_folds)
    paths = get_simulated_paths(paths)
    # 3. Train/Val Split (80/20)
    # NOTE(review): random split at the filename level, not stratified by label;
    # confirm this is intended.
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")
    # 4. Initialize DataModule
    # Augmentations are passed for training only; validation gets no transform.
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=paths, # Initial dummy; replaced below before each setup() call
        train_transform=get_training_transforms(),
        val_transform=None,
    )
    # 5. Create Loaders
    # One DataModule serves both splits: `paths` is swapped before each setup().
    # Statement order matters here.
    # NOTE(review): this assumes setup() reads `paths`, and that the train
    # loader keeps its own dataset after `paths` is switched to the validation
    # list. Confirm against MethaneSimulatedDataModule.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()
    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()
    return train_loader, val_loader
| # --- Main Execution --- | |
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for I1 (TOA reflectance) training.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as ``argparse.ArgumentParser.parse_args`` does.

    Returns:
        The parsed namespace.

    Exits via ``parser.error`` (``SystemExit``) in two cases:
    - ``--num_folds`` is below 2, which would leave no training fold.
    - ``--test_fold`` is outside ``1..num_folds``. Such a value would silently
      hold out no fold, and every fold would be used for training.
    """
    parser = argparse.ArgumentParser(description="Methane I1 (TOA Refl) Training")
    # Paths
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for I1/TOA data')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel')
    parser.add_argument('--save_dir', type=str, default='./checkpoints_i1', help='Output directory')
    # Hyperparameters
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--num_folds', type=int, default=5)
    parser.add_argument('--test_fold', type=int, default=4, help='Fold ID to hold out')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args(argv)
    if args.num_folds < 2:
        parser.error(f"argument --num_folds: must be at least 2 (one fold is held out), got {args.num_folds}")
    if not 1 <= args.test_fold <= args.num_folds:
        parser.error(f"argument --test_fold: must be between 1 and {args.num_folds}, got {args.test_fold}")
    return args
def main() -> None:
    """Script entry point: parse CLI args, seed RNGs, build loaders, then train."""
    cli_args = parse_args()
    set_seed(cli_args.seed)
    loaders = get_data_loaders(cli_args)
    TrainerI1(cli_args).fit(*loaders)


if __name__ == "__main__":
    main()