# TerraMind-Methane-Classification / sentinel2_classification_finetuning / script / inference_s2_simulated.py
| import argparse | |
| import logging | |
| import csv | |
| import random | |
| import warnings | |
| import time | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Any, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import albumentations as A | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from rasterio.errors import NotGeoreferencedWarning | |
| # --- CRITICAL IMPORTS --- | |
| import terramind | |
| from terratorch.tasks import ClassificationTask | |
| # Local Imports | |
| from methane_simulated_datamodule import MethaneSimulatedDataModule | |
# --- Configuration & Setup ---
# Module-wide logging: timestamped INFO-level messages for this script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Silence noisy rasterio environment logs.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
# Simulated chips typically carry no geotransform, so this warning would
# otherwise fire once per file read.
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Also seed every visible CUDA device when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def get_inference_transforms() -> "A.Compose | None":
    """Return the albumentations pipeline applied at inference time.

    No augmentation is used for inference, so this returns ``None`` (the
    datamodule treats ``None`` as "no transform"). The annotation is kept
    optional — the previous ``A.Compose`` annotation was incorrect because
    the function always returns ``None``.
    """
    return None
# --- Path Utilities (Crucial for Simulated Data) ---
def get_simulated_paths(paths: List[str]) -> List[str]:
    """
    Modifies filenames to match the simulated dataset naming convention.

    The simulated dataset replaces the second underscore-separated token
    (the numeric tile id) with the literal ``toarefl``:
    Original: 'MBD_0001_S2_123_456' -> Simulated: 'MBD_toarefl_S2_123_456'

    Paths with fewer than five tokens (or that fail to parse) are passed
    through unchanged.
    """
    simulated_paths = []
    for path in paths:
        try:
            tokens = path.split('_')
            if len(tokens) >= 5:
                # Replace tokens[1] with 'toarefl' and keep every remaining
                # token. The previous f-string skipped tokens[2] (the 'S2'
                # marker), yielding 'MBD_toarefl_123_456' instead of the
                # documented 'MBD_toarefl_S2_123_456'.
                simulated_path = '_'.join([tokens[0], 'toarefl'] + tokens[2:])
                simulated_paths.append(simulated_path)
            else:
                # Not enough tokens to rewrite safely; keep the original.
                simulated_paths.append(path)
        except Exception as e:
            logger.warning(f"Could not parse path {path}: {e}")
            simulated_paths.append(path)
    return simulated_paths
# --- Inference Class ---
class SimulatedInference:
    """Binary methane classification inference on simulated Sentinel-2 chips.

    Builds a TerraMind-backed classifier via terratorch's ClassificationTask,
    loads fine-tuned weights from a checkpoint, runs batched inference, and
    writes per-sample predictions to a CSV file in ``output_dir``.
    """

    def __init__(self, args: argparse.Namespace):
        """Build the model, load the checkpoint and prepare the output dir.

        Args:
            args: Parsed CLI namespace; must provide ``output_dir`` and
                ``checkpoint``.
        """
        self.args = args
        # Prefer GPU when available; model and inputs both follow this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initializing Inference on device: {self.device}")
        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Instantiate the TerraMind classification model (no pretrained weights).

        Returns:
            The bare ``nn.Module`` unwrapped from the terratorch task and
            moved to ``self.device``. Weights are applied afterwards by
            ``_load_checkpoint``, hence ``backbone_pretrained=False``.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,  # weights come from the checkpoint instead
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,  # binary: methane plume vs. background
            head_dropout=0.3,
            necks=[
                # Reshape ViT token sequence back into a 2D feature map.
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                # Tap four intermediate transformer depths for the decoder.
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )
        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        """Load fine-tuned weights and switch the model to eval mode.

        Accepts either a Lightning-style checkpoint (dict with a
        ``state_dict`` key) or a raw state dict. Loading is non-strict, so
        key mismatches are tolerated silently — NOTE(review): consider
        logging the missing/unexpected keys ``load_state_dict`` returns.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")
        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates predictions and matches them with provided sample identifiers.

        Assumes ``dataloader`` yields batches in the same order as
        ``sample_names`` (i.e. no shuffling): predictions are paired with
        names positionally via an iterator.
        """
        results = {}
        logger.info(f"Starting inference on {len(sample_names)} samples...")
        # Iterator to match predictions with original filenames
        name_iter = iter(sample_names)
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)
                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)
                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()
                # Assign to Sample Names
                for pred in batch_preds:
                    try:
                        sample_name = next(name_iter)
                        results[sample_name] = int(pred)
                    except StopIteration:
                        # More model outputs than names: the loader and the
                        # name list are out of sync; stop pairing.
                        logger.error("More predictions than sample names! Check sync.")
                        break
        if len(results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(results)}.")
        # Save CSV
        self._save_results(results)

    def _save_results(self, results: Dict[str, int]):
        """Write ``{sample_id: prediction}`` pairs to a two-column CSV."""
        csv_path = self.output_dir / "simulated_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_ID', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")
# --- Data Loading ---
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the inference dataloader and the ordered simulated sample names.

    Reads the summary Excel sheet, optionally keeps only the requested
    folds, rewrites filenames to the simulated naming scheme, and wires
    everything into the project datamodule.
    """
    # 1. Read Excel to get base filenames
    try:
        frame = pd.read_excel(args.excel_file)
        # If specific folds are requested, filter them
        if args.folds:
            folds_to_use = [int(token) for token in args.folds.split(',')]
            frame = frame[frame['Fold'].isin(folds_to_use)]
            logger.info(f"Filtered to folds: {folds_to_use}")
        raw_paths = frame['Filename'].tolist()
        logger.info(f"Loaded {len(raw_paths)} paths from Excel.")
    except Exception as e:
        logger.error(f"Error reading Excel: {e}")
        raise
    # 2. Transform paths to Simulated format
    simulated_paths = get_simulated_paths(raw_paths)
    # 3. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=simulated_paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
    )
    # 4. Setup
    datamodule.paths = simulated_paths
    datamodule.setup(stage="test")
    # Fall back to the train loader only when no test loader is exposed.
    if hasattr(datamodule, 'test_dataloader'):
        loader = datamodule.test_dataloader()
    else:
        loader = datamodule.train_dataloader()
    return loader, simulated_paths
# --- Main Execution ---
def parse_args():
    """Parse and return the command-line arguments for this script."""
    ap = argparse.ArgumentParser(description="Methane Simulated S2 Inference")
    add = ap.add_argument
    # Required inputs
    add('--root_dir', type=str, required=True, help='Root directory for simulated data')
    add('--excel_file', type=str, required=True, help='Path to Summary Excel file')
    add('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')
    # Optional settings
    add('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    add('--folds', type=str, default=None, help='Comma-separated list of folds to infer on (e.g., "4" or "1,2"). If None, uses all.')
    add('--batch_size', type=int, default=1, help='Inference batch size')
    add('--seed', type=int, default=42)
    return ap.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # Seed all RNGs up front so any stochastic components behave reproducibly.
    set_seed(args.seed)
    # 1. Prepare Data & Names (fails fast on Excel/path problems before the
    #    heavier model construction below).
    dataloader, sample_names = get_dataloader_and_names(args)
    # 2. Run Inference
    engine = SimulatedInference(args)
    engine.run_inference(dataloader, sample_names)