# TerraMind-Methane-Classification / classification / script / train_classification_fine_tuning.py
# Author: KPLabs. Uploaded with huggingface_hub ("Upload folder using huggingface_hub"), commit 97a17c2.
import argparse
import logging
import csv
import random
import warnings
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
accuracy_score, recall_score, f1_score, matthews_corrcoef, confusion_matrix
)
from rasterio.errors import NotGeoreferencedWarning
import terramind
# Local Imports
from methane_classification_datamodule import MethaneClassificationDataModule
# TerraTorch Imports
from terratorch.tasks import ClassificationTask
# --- Configuration & Setup ---
# Configure Logging: INFO-level, "timestamp - logger name - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Suppress Warnings
# Only ERROR and above from rasterio's environment logger are shown.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
# Ignore NotGeoreferencedWarning and all FutureWarnings (order of filters kept as-is).
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    CUDA generators on every visible device are also seeded when CUDA is
    available.

    Args:
        seed: Value passed to each generator's seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def get_training_transforms() -> A.Compose:
    """Build the albumentations augmentation pipeline used for training.

    Each transform is applied independently with its own probability ``p``:
    elastic deformation (0.25), random 90-degree rotation (0.5), flip (0.5),
    and shift/scale/rotate with ``rotate_limit=90`` and a 0.05 shift limit per
    axis (0.5).
    """
    # NOTE(review): ``A.Flip`` and ``A.ShiftScaleRotate`` are deprecated in newer
    # albumentations releases; pin the albumentations version or migrate deliberately,
    # since replacements (e.g. HorizontalFlip/VerticalFlip, Affine) sample differently.
    augmentations = [
        A.ElasticTransform(p=0.25),
        A.RandomRotate90(p=0.5),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(
            rotate_limit=90,
            shift_limit_x=0.05,
            shift_limit_y=0.05,
            p=0.5,
        ),
    ]
    return A.Compose(augmentations)
# --- Helper Classes ---
class MetricTracker:
    """Accumulate per-batch losses, targets and predictions for epoch-level metrics.

    Targets may be given either as class indices (shape ``(B,)``) or as
    one-hot / class-probability rows (shape ``(B, C)``); both are reduced to
    integer class labels before being stored. Predictions are the argmax of
    ``probabilities`` along dim 1.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated state."""
        self.all_targets = []
        self.all_predictions = []
        self.total_loss = 0.0  # sum over batches of (batch loss * batch size)
        self.steps = 0  # number of update() calls, i.e. batches seen
        self.num_samples = 0  # number of samples seen; denominator of the epoch loss

    def update(self, loss: float, targets: torch.Tensor, probabilities: torch.Tensor):
        """Record one batch.

        Args:
            loss: The batch loss. Assumed to be a mean over the batch (the default
                reduction of cross-entropy) -- TODO confirm for other criteria.
            targets: Class indices ``(B,)`` or one-hot / probability rows ``(B, C)``.
            probabilities: Per-class scores ``(B, C)``; only their argmax is used.
        """
        # Reduce one-hot / soft targets to class indices; pass index targets through.
        labels = targets.argmax(dim=1) if targets.dim() > 1 else targets.long()
        predictions = probabilities.argmax(dim=1)
        batch_size = int(labels.shape[0])

        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        self.total_loss += float(loss) * batch_size
        self.num_samples += batch_size
        self.steps += 1

        # Store detached CPU values so no device memory is retained across the epoch.
        self.all_targets.extend(labels.detach().cpu().numpy().tolist())
        self.all_predictions.extend(predictions.detach().cpu().numpy().tolist())

    def compute(self) -> Dict[str, float]:
        """Return aggregate metrics for the accumulated data.

        Returns an empty dict when nothing has been recorded. Otherwise returns
        the sample-weighted mean loss plus accuracy, specificity, sensitivity
        (recall of class 1), F1 (class 1) and MCC, all as Python floats.
        """
        if not self.all_targets:
            return {}

        y_true, y_pred = self.all_targets, self.all_predictions
        # labels=[0, 1] guarantees a 2x2 matrix even if a class is absent.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

        return {
            "Loss": self.total_loss / max(self.num_samples, 1),
            "Accuracy": float(accuracy_score(y_true, y_pred)),
            "Specificity": float(tn / (tn + fp)) if (tn + fp) != 0 else 0.0,
            "Sensitivity": float(recall_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0)),
            "F1": float(f1_score(y_true, y_pred, average='binary', pos_label=1, zero_division=0)),
            "MCC": float(matthews_corrcoef(y_true, y_pred)),
        }
class MethaneTrainer:
    """
    Handles the training lifecycle: model setup, training loop, validation,
    checkpointing and CSV logging.

    Outputs (checkpoints and the metrics CSV) go to ``<save_dir>/fold<test_fold>``.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_dir = Path(args.save_dir) / f'fold{args.test_fold}'
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.model = self._init_model()  # also sets self.task
        self.optimizer, self.scheduler = self._init_optimizer()
        self.criterion = self.task.criterion  # Retrieved from the TerraTorch task (loss="ce")

        self.best_val_loss = float('inf')
        logger.info(f"Trainer initialized on device: {self.device}")

    def _init_model(self) -> nn.Module:
        """Initialize the TerraTorch ClassificationTask and return its model on ``self.device``.

        Side effect: stores the task on ``self.task`` (its criterion is reused by the trainer).
        """
        model_config = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=True,
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )
        self.task = ClassificationTask(
            model_args=model_config,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            lr=self.args.lr,
            ignore_index=-1,
            optimizer="AdamW",
            optimizer_hparams={"weight_decay": self.args.weight_decay},
        )
        self.task.configure_models()
        self.task.configure_losses()
        return self.task.model.to(self.device)

    def _init_optimizer(self):
        """Build an AdamW optimizer and a ReduceLROnPlateau scheduler (minimizing val loss).

        ``verbose`` is intentionally not passed: it is deprecated in recent PyTorch
        releases. Learning-rate reductions are logged by ``fit`` instead.
        """
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=5
        )
        return optimizer, scheduler

    def _forward_step(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the model on one batch and compute the loss.

        Returns:
            (loss, probabilities, targets), with ``probabilities`` the softmax
            of the model logits along dim 1.
        """
        inputs = batch['S2L2A'].to(self.device)
        targets = batch['label'].to(self.device)

        logits = self.model(x={"S2L2A": inputs}).output
        # The "ce" criterion applies log-softmax itself, so it must receive raw
        # logits. Passing softmax probabilities would apply softmax twice,
        # compressing the loss range and weakening gradients.
        loss = self.criterion(logits, targets)
        probabilities = torch.softmax(logits, dim=1)
        return loss, probabilities, targets

    def run_epoch(self, dataloader: DataLoader, stage: str = "train") -> Dict[str, float]:
        """Run a single epoch for either training (``stage == "train"``) or evaluation.

        Returns:
            The epoch metrics from ``MetricTracker.compute``.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        is_train = stage == "train"
        self.model.train(is_train)  # eval mode for any stage other than "train"
        tracker = MetricTracker()

        # Gradients are only tracked while training.
        with torch.set_grad_enabled(is_train):
            pbar = tqdm(dataloader, desc=f" {stage.capitalize()}", leave=False)
            for batch in pbar:
                loss, probabilities, targets = self._forward_step(batch)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                loss_value = loss.item()
                tracker.update(loss_value, targets, probabilities)
                pbar.set_postfix(loss=f"{loss_value:.4f}")

        metrics = tracker.compute()
        if not metrics:
            raise ValueError(f"The {stage} dataloader yielded no batches.")
        return metrics

    def save_checkpoint(self, filename: str):
        """Save the model ``state_dict`` to ``self.save_dir / filename``."""
        path = self.save_dir / filename
        torch.save(self.model.state_dict(), path)
        logger.info(f"Saved model to {path}")

    def log_to_csv(self, epoch: int, train_metrics: Dict, val_metrics: Dict):
        """Append one row of metrics to ``train_val_metrics.csv`` in ``self.save_dir``.

        The header is written only when the file does not exist yet. Rows written
        to a pre-existing file (e.g. from an earlier run in the same directory)
        are appended under its existing header.
        """
        csv_path = self.save_dir / 'train_val_metrics.csv'
        file_exists = csv_path.exists()

        # Column order follows the metric dicts' key order.
        headers = ['Epoch'] + [f'Train_{k}' for k in train_metrics.keys()] + [f'Val_{k}' for k in val_metrics.keys()]

        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(headers)
            row = [epoch] + list(train_metrics.values()) + list(val_metrics.values())
            writer.writerow(row)

    def fit(self, train_loader: DataLoader, val_loader: DataLoader):
        """Main training entry point.

        Trains for ``args.epochs`` epochs, stepping the plateau scheduler on the
        validation loss and saving ``best_model.pth`` whenever it improves.
        ``final_model.pth`` is saved at the end.
        """
        logger.info(f"Starting training for {self.args.epochs} epochs...")
        start_time = time.time()

        for epoch in range(1, self.args.epochs + 1):
            logger.info(f"Epoch {epoch}/{self.args.epochs}")

            # Run Training & Validation
            train_metrics = self.run_epoch(train_loader, stage="train")
            val_metrics = self.run_epoch(val_loader, stage="validate")

            # Scheduler Step (log reductions explicitly instead of scheduler verbose)
            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_metrics['Loss'])
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                logger.info(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Logging
            self.log_to_csv(epoch, train_metrics, val_metrics)
            logger.info(
                f"Train Loss: {train_metrics['Loss']:.4f} | "
                f"Val Loss: {val_metrics['Loss']:.4f} | "
                f"Val F1: {val_metrics['F1']:.4f}"
            )

            # Save Best Model
            if val_metrics['Loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['Loss']
                self.save_checkpoint("best_model.pth")
                logger.info(f"--> New best model (Val Loss: {self.best_val_loss:.4f})")

        # End of training
        self.save_checkpoint("final_model.pth")
        logger.info(f"Training finished in {time.time() - start_time:.2f}s")
# --- Data Utilities ---
def get_data_loaders(args) -> Tuple[DataLoader, DataLoader]:
    """Prepare the DataModule and return (train_loader, val_loader).

    Rows of the summary file whose ``Fold`` differs from ``args.test_fold`` form
    the training pool. Their ``Filename`` entries are split 80/20 (seeded with
    ``args.seed``) into training and validation paths.

    Raises:
        ValueError: if no rows belong to the training-pool folds.
        Exception: re-raises whatever pandas raises when the summary file
            cannot be read (after logging it).
    """
    # Read the summary file; the suffix check is case-insensitive so that
    # e.g. "summary.CSV" is not routed to read_excel.
    is_csv = Path(args.excel_file).suffix.lower() == '.csv'
    try:
        df = pd.read_csv(args.excel_file) if is_csv else pd.read_excel(args.excel_file)
    except Exception as e:
        logger.error(f"Failed to load summary file: {e}")
        raise

    # Determine training pool (all folds except test_fold)
    all_folds = range(1, args.num_folds + 1)
    if args.test_fold not in all_folds:
        # Nothing is held out in this case: every fold ends up in the training pool.
        logger.warning(
            f"test_fold={args.test_fold} is outside 1..{args.num_folds}; "
            f"no fold is held out for testing."
        )
    train_pool_folds = [f for f in all_folds if f != args.test_fold]

    # Filter filenames
    df_filtered = df[df['Fold'].isin(train_pool_folds)]
    if df_filtered.empty:
        raise ValueError(f"No data found for folds {train_pool_folds}. Check 'Fold' column in Excel.")
    paths = df_filtered['Filename'].tolist()

    # 80/20 Split
    train_paths, val_paths = train_test_split(paths, test_size=0.2, random_state=args.seed)
    logger.info(f"Data Split - Train: {len(train_paths)}, Val: {len(val_paths)} (Test Fold: {args.test_fold})")

    # Initialize DataModule
    datamodule = MethaneClassificationDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=train_paths,
        train_transform=get_training_transforms(),
        val_transform=None,
    )

    # Create Loaders.
    # NOTE(review): this reuses one datamodule by swapping ``paths`` between
    # setup() calls; it assumes setup() builds its dataset from the current
    # ``paths`` and that the already-created train loader keeps its own dataset.
    # Confirm against MethaneClassificationDataModule.
    datamodule.paths = train_paths
    datamodule.setup(stage="fit")
    train_loader = datamodule.train_dataloader()

    datamodule.paths = val_paths
    datamodule.setup(stage="validate")
    val_loader = datamodule.val_dataloader()

    return train_loader, val_loader
# --- Main Execution ---
def parse_args():
    """Parse the command-line options for one cross-validation training run.

    Returns:
        argparse.Namespace holding the path options (root_dir, excel_file,
        save_dir) and the training options (epochs, batch_size, lr,
        weight_decay, num_folds, test_fold, seed).
    """
    parser = argparse.ArgumentParser(description="Methane Classification Training with TerraTorch")

    # Paths
    path_options = (
        ('--root_dir', dict(type=str, required=True, help='Root directory for satellite images')),
        ('--excel_file', dict(type=str, required=True, help='Path to summary Excel/CSV file')),
        ('--save_dir', dict(type=str, default='./checkpoints', help='Directory to save outputs')),
    )
    # Training Config
    training_options = (
        ('--epochs', dict(type=int, default=100)),
        ('--batch_size', dict(type=int, default=8)),
        ('--lr', dict(type=float, default=1e-5)),
        ('--weight_decay', dict(type=float, default=0.05)),
        ('--num_folds', dict(type=int, default=5)),
        ('--test_fold', dict(type=int, default=2, help='Fold ID to hold out for testing')),
        ('--seed', dict(type=int, default=42)),
    )

    for flag, options in path_options + training_options:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Parse arguments, seed the RNGs, build the data loaders and train one fold."""
    args = parse_args()
    set_seed(args.seed)

    # Prepare Data
    train_loader, val_loader = get_data_loaders(args)

    # Initialize Trainer and Start
    MethaneTrainer(args).fit(train_loader, val_loader)


if __name__ == "__main__":
    main()