"""Run a trained TerraMind classifier on the simulated methane S2 dataset.

Per-sample predictions are written to ``<output_dir>/simulated_predictions.csv``.
"""
import argparse
import logging
import csv
import random
import warnings
import time
import os
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from rasterio.errors import NotGeoreferencedWarning

# --- CRITICAL IMPORTS ---
# `terramind` is never referenced directly; it is presumably imported for its
# side effects (registering the TerraMind backbones) — keep it.
import terramind  # noqa: F401
from terratorch.tasks import ClassificationTask

# Local Imports
from methane_simulated_datamodule import MethaneSimulatedDataModule

# --- Configuration & Setup ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_inference_transforms() -> Optional[A.Compose]:
    """Return the transform used at inference time.

    Currently ``None``: no albumentations pipeline is applied here; whatever
    the datamodule does with ``val_transform=None`` applies (TODO confirm).
    """
    return None


# --- Path Utilities (Crucial for Simulated Data) ---
def get_simulated_paths(paths: List[str]) -> List[str]:
    """Map Excel filenames to the simulated dataset's naming convention.

    A name with at least five ``_``-separated tokens ``t0_t1_t2_t3_t4[_...]``
    becomes ``f"{t0}_toarefl_{t3}_{t4}"``: tokens 1 and 2, and anything after
    token 4, are dropped. For example ``'A_B_C_S2_123'`` maps to
    ``'A_toarefl_S2_123'``.

    Names with fewer than five tokens are passed through unchanged. Entries
    that are not strings (e.g. NaN cells from Excel) are also passed through,
    with a warning.

    NOTE(review): an older description claimed 'MBD_0001_S2_...' ->
    'MBD_toarefl_S2_...', which this logic does NOT produce for a 5-token name
    (it would give 'MBD_toarefl_<t3>_<t4>'). Confirm against the real
    filenames.
    """
    simulated_paths = []
    for path in paths:
        try:
            tokens = path.split('_')
        except (AttributeError, TypeError) as e:
            # Non-string Excel cells (e.g. NaN floats) have no usable .split.
            logger.warning(f"Could not parse path {path}: {e}")
            simulated_paths.append(path)
            continue

        # Reconstruct filename based on notebook logic
        if len(tokens) >= 5:
            simulated_paths.append(f"{tokens[0]}_toarefl_{tokens[3]}_{tokens[4]}")
        else:
            simulated_paths.append(path)
    return simulated_paths


# --- Inference Class ---
class SimulatedInference:
    """Build the TerraMind classifier, load trained weights and predict."""

    def __init__(self, args: argparse.Namespace):
        """Create the output directory, build the model and load the checkpoint.

        Reads ``args.output_dir`` and ``args.checkpoint``.
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initializing Inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Instantiate the classification network and move it to ``self.device``.

        The configuration must match the one used for training, or the
        checkpoint's tensors will not line up with the model's.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,  # real weights come from the checkpoint
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
                {"name": "LearnedInterpolateToPyramidal"},
            ],
        )
        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    @staticmethod
    def _align_state_dict_keys(state_dict: Any, model_keys: set,
                               prefix: str = "model.") -> Any:
        """Strip ``prefix`` from checkpoint keys when that makes them match.

        A checkpoint saved from the task object stores the network under the
        task's ``model`` attribute, so its keys may read ``model.<name>``.
        ``self.model`` is the bare network and expects ``<name>``. The
        stripped dict is returned only if it matches keys the raw dict did
        not; otherwise ``state_dict`` is returned untouched.
        """
        if not isinstance(state_dict, dict):
            return state_dict
        if model_keys.intersection(state_dict):
            return state_dict
        stripped = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }
        if model_keys.intersection(stripped):
            logger.info(f"Stripped '{prefix}' prefix from checkpoint keys.")
            return stripped
        return state_dict

    def _load_checkpoint(self, checkpoint_path: str):
        """Load trained weights from ``checkpoint_path`` into ``self.model``.

        Accepts either a raw state dict or a dict wrapping one under the
        ``'state_dict'`` key. Loading is non-strict, but missing and unexpected
        keys are logged, and a checkpoint that matches none of the model's
        keys is rejected instead of silently leaving random weights.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` does not exist.
            RuntimeError: if no model parameter/buffer was found in the checkpoint.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        model_keys = set(self.model.state_dict())
        state_dict = self._align_state_dict_keys(state_dict, model_keys)

        incompatible = self.model.load_state_dict(state_dict, strict=False)
        missing = list(incompatible.missing_keys)
        unexpected = list(incompatible.unexpected_keys)

        if model_keys and len(missing) == len(model_keys):
            raise RuntimeError(
                f"No weights in {path} matched the model's parameters; refusing "
                "to run inference with randomly initialised weights."
            )
        if missing:
            logger.warning(
                f"{len(missing)}/{len(model_keys)} model keys missing from "
                f"checkpoint (e.g. {missing[:5]})."
            )
        if unexpected:
            logger.warning(
                f"{len(unexpected)} unexpected checkpoint keys ignored "
                f"(e.g. {unexpected[:5]})."
            )
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """Predict a class for every sample and save the results as CSV.

        Predictions are paired with ``sample_names`` purely by position, so
        ``dataloader`` must yield samples in exactly that order (no shuffling,
        no dropped samples). Each batch is expected to be a dict holding the
        input tensor under ``'S2L2A'``. A length mismatch is logged, and only
        the overlapping prefix is written.
        """
        logger.info(f"Starting inference on {len(sample_names)} samples...")

        duplicates = [name for name, count in Counter(sample_names).items() if count > 1]
        if duplicates:
            logger.warning(
                f"{len(duplicates)} sample names occur more than once "
                f"(e.g. {duplicates[:5]}); every occurrence is kept in the CSV."
            )

        predictions: List[int] = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                # argmax over class logits == argmax over their softmax.
                batch_preds = torch.argmax(outputs.output, dim=1).cpu().numpy()
                predictions.extend(int(pred) for pred in batch_preds)

        if len(predictions) > len(sample_names):
            logger.error("More predictions than sample names! Check sync.")

        # zip truncates to the shorter side; a list keeps duplicate names.
        results = list(zip(sample_names, predictions))
        if len(results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(results)}.")

        self._save_results(results)

    def _save_results(self, results: Union[Dict[str, int], Iterable[Tuple[str, int]]]):
        """Write ``(sample_id, prediction)`` rows to ``simulated_predictions.csv``.

        ``results`` may be a name->prediction dict or an iterable of pairs.
        """
        csv_path = self.output_dir / "simulated_predictions.csv"
        rows = results.items() if isinstance(results, dict) else results
        with open(csv_path, mode='w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_ID', 'Prediction'])
            writer.writerows(rows)
        logger.info(f"Predictions saved to {csv_path}")


# --- Data Loading ---
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the inference dataloader and the sample names in loader order.

    Reads ``Filename`` (and ``Fold`` when ``--folds`` is given) from the Excel
    summary. It then converts the names with :func:`get_simulated_paths` and
    hands them to the datamodule. The returned names are those converted paths.
    """
    # 1. Read Excel to get base filenames
    try:
        df = pd.read_excel(args.excel_file)

        # If specific folds are requested, filter them ("1,,2" / "1," tolerated)
        if args.folds:
            folds_to_use = [int(f) for f in args.folds.split(',') if f.strip()]
            df = df[df['Fold'].isin(folds_to_use)]
            logger.info(f"Filtered to folds: {folds_to_use}")

        raw_paths = df['Filename'].tolist()
        logger.info(f"Loaded {len(raw_paths)} paths from Excel.")
    except Exception as e:
        logger.error(f"Error reading Excel: {e}")
        raise

    if not raw_paths:
        logger.warning("No samples selected; check --excel_file and --folds.")

    # 2. Transform paths to Simulated format
    simulated_paths = get_simulated_paths(raw_paths)

    # 3. Initialize DataModule
    datamodule = MethaneSimulatedDataModule(
        data_root=args.root_dir,
        excel_file=args.excel_file,
        batch_size=args.batch_size,
        paths=simulated_paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
    )

    # 4. Setup
    datamodule.paths = simulated_paths
    datamodule.setup(stage="test")

    # NOTE(review): the train_dataloader fallback may shuffle, which would
    # desync predictions from `simulated_paths` (they are matched by position).
    loader = datamodule.test_dataloader() if hasattr(datamodule, 'test_dataloader') else datamodule.train_dataloader()

    return loader, simulated_paths


# --- Main Execution ---
def parse_args():
    """Parse the command-line options of this inference script."""
    parser = argparse.ArgumentParser(description="Methane Simulated S2 Inference")
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory for simulated data')
    parser.add_argument('--excel_file', type=str, required=True, help='Path to Summary Excel file')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')
    parser.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    parser.add_argument('--folds', type=str, default=None, help='Comma-separated list of folds to infer on (e.g., "4" or "1,2"). If None, uses all.')
    parser.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
    parser.add_argument('--seed', type=int, default=42)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)

    # 1. Prepare Data & Names
    dataloader, sample_names = get_dataloader_and_names(args)

    # 2. Run Inference
    engine = SimulatedInference(args)
    engine.run_inference(dataloader, sample_names)