| import argparse |
| import logging |
| import csv |
| import random |
| import warnings |
| import time |
| import os |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Any, Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import albumentations as A |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from rasterio.errors import NotGeoreferencedWarning |
|
|
| |
| import terramind |
| from terratorch.tasks import ClassificationTask |
|
|
| |
| from methane_urban_datamodule import MethaneUrbanDataModule |
|
|
| |
|
|
# Root logging configuration for the whole script: timestamped INFO-level lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used by all functions/classes below.
logger = logging.getLogger(__name__)


# Silence noisy third-party diagnostics that are irrelevant for inference:
# rasterio environment chatter, "not georeferenced" warnings for plain
# (non-geo) rasters, and FutureWarnings from library deprecation churn.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
def set_seed(seed: int = 42) -> None:
    """Seed every RNG in use (python, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
def get_inference_transforms() -> Optional["A.Compose"]:
    """Return the albumentations pipeline applied at inference time.

    Currently no augmentation/preprocessing is applied (raw tensors are fed
    straight to the model), so this deliberately returns None. The return
    annotation is Optional to match that contract — the previous annotation
    (``A.Compose``) promised a value this function never produced.
    """
    return None
|
|
| |
|
|
class UrbanInference:
    """Binary-classification inference driver.

    Builds a TerraMind classification model, loads checkpoint weights, runs
    batched prediction over a dataloader, and writes one prediction per
    sample directory to a CSV in ``output_dir``.
    """

    def __init__(self, args: argparse.Namespace):
        """Set up device, output directory, model, and checkpoint weights.

        Args:
            args: Parsed CLI namespace; must provide ``output_dir`` and
                ``checkpoint``.
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initializing Inference on device: {self.device}")

        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Build the TerraMind backbone + UperNet decoder classification model.

        Weights are loaded separately in ``_load_checkpoint``, so the backbone
        is created without pretrained weights.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,  # weights come from the checkpoint instead
            backbone_modalities=["S2L2A"],
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,  # binary classification (classes 0 and 1)
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                # Feed four intermediate transformer blocks to the decoder.
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )

        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from disk and switch the model to eval mode.

        Accepts both Lightning-style checkpoints (weights nested under
        ``state_dict``) and plain state dicts.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` does not exist.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")

        logger.info(f"Loading weights from {path}...")
        # weights_only=False: Lightning checkpoints pickle extra (non-tensor)
        # objects, which the torch>=2.6 default (weights_only=True) refuses to
        # load. The checkpoint is a trusted local file supplied by the user.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # strict=False tolerates wrapper-prefix mismatches, but previously any
        # skipped keys were discarded silently — surface them so a wrong
        # checkpoint is noticed instead of producing garbage predictions.
        load_result = self.model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(f"Missing keys when loading checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            logger.warning(f"Unexpected keys in checkpoint: {load_result.unexpected_keys}")
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates binary predictions and matches them with provided sample_names (folder names).

        NOTE(review): predictions are paired with names purely positionally, so
        the dataloader must yield samples in exactly the order of
        ``sample_names`` (no shuffling) — verify against the datamodule.
        """
        sample_results: Dict[str, int] = {}

        logger.info(f"Starting inference on {len(sample_names)} samples...")

        name_iter = iter(sample_names)
        names_exhausted = False

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                inputs = batch['S2L2A'].to(self.device)

                outputs = self.model(x={"S2L2A": inputs})
                probabilities = torch.softmax(outputs.output, dim=1)

                # Predicted class index (0 or 1) per sample.
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()

                for pred in batch_preds:
                    dir_name = next(name_iter, None)
                    if dir_name is None:
                        logger.error("More predictions generated than sample names provided! Check dataloader sync.")
                        names_exhausted = True
                        break
                    sample_results[dir_name] = int(pred)

                if names_exhausted:
                    # Fix: previously only the inner loop was aborted, so every
                    # remaining batch was still computed (and the error re-logged)
                    # even though no name was left to record the result under.
                    break

        if len(sample_results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(sample_results)}.")

        self._save_results(sample_results)

    def _save_results(self, results: Dict[str, int]):
        """Write {sample_directory: prediction} rows to inference_predictions.csv."""
        csv_path = self.output_dir / "inference_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_Directory', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")
|
|
| |
|
|
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the test dataloader plus the ordered list of sample names.

    Names come either from the optional Excel file (first underscore-separated
    token of each 'Filename') or, failing that, from the sorted subdirectory
    names of ``args.root_dir``.

    Raises:
        FileNotFoundError: if ``args.root_dir`` does not exist.
    """
    root_path = Path(args.root_dir)
    if not root_path.exists():
        raise FileNotFoundError(f"Data directory {args.root_dir} not found.")

    paths: Optional[List[str]] = None
    if args.excel_file:
        try:
            df = pd.read_excel(args.excel_file)
            # Directory name is the prefix before the first underscore.
            paths = [str(name).split('_')[0] for name in df['Filename']]
            logger.info(f"Filtered {len(paths)} samples from Excel.")
        except Exception as e:
            logger.error(f"Error reading Excel: {e}")
            raise

    if paths is None:
        # Deterministic ordering so predictions line up with sample names.
        paths = sorted(d.name for d in root_path.iterdir() if d.is_dir())
        logger.info(f"Found {len(paths)} samples in directory (Sorted).")

    datamodule = MethaneUrbanDataModule(
        data_root=args.root_dir,
        excel_file=None,
        batch_size=args.batch_size,
        paths=paths,
        train_transform=None,
        val_transform=get_inference_transforms(),
        test_transform=get_inference_transforms()
    )

    # NOTE(review): paths is passed to the constructor AND re-assigned here —
    # presumably to guarantee the attribute exists before setup(); confirm
    # against MethaneUrbanDataModule before removing either one.
    datamodule.paths = paths
    datamodule.setup(stage="test")

    if hasattr(datamodule, 'test_dataloader'):
        loader = datamodule.test_dataloader()
    else:
        loader = datamodule.train_dataloader()

    return loader, paths
|
|
| |
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the inference script."""
    cli = argparse.ArgumentParser(description="Methane Urban Inference (Directory Names)")

    cli.add_argument('--root_dir', type=str, required=True, help='Root directory containing sample folders')
    cli.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint (.pth)')
    cli.add_argument('--excel_file', type=str, help='Optional Excel file to filter specific samples')
    cli.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    cli.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
    cli.add_argument('--seed', type=int, default=42)

    return cli.parse_args()
|
|
| if __name__ == "__main__": |
| args = parse_args() |
| set_seed(args.seed) |
| |
| |
| dataloader, sample_names = get_dataloader_and_names(args) |
| |
| |
| engine = UrbanInference(args) |
| engine.run_inference(dataloader, sample_names) |