# KPLabs — uploaded via huggingface_hub (commit 97a17c2, verified)
import argparse
import logging
import csv
import random
import warnings
import time
import os
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm
from rasterio.errors import NotGeoreferencedWarning
# --- CRITICAL IMPORTS ---
import terramind
from terratorch.tasks import ClassificationTask
# Local Imports
from methane_urban_datamodule import MethaneUrbanDataModule
# --- Configuration & Setup ---
# Global logging format: timestamped, module-scoped messages at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
# Silence noisy rasterio environment chatter; only errors pass through.
logging.getLogger("rasterio._env").setLevel(logging.ERROR)
# The rasters processed here are expected to lack georeferencing metadata,
# so suppress the per-file NotGeoreferencedWarning spam.
warnings.simplefilter("ignore", NotGeoreferencedWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
def set_seed(seed: int = 42):
    """Seed Python, NumPy, and torch RNGs so runs are reproducible.

    Args:
        seed: Seed value applied to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Only touch CUDA generators when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def get_inference_transforms() -> "Optional[A.Compose]":
    """Return the albumentations pipeline applied at inference time.

    No augmentation/normalization is applied at inference in this script
    (the datamodule is assumed to handle any preprocessing — TODO confirm),
    so this returns None. The previous annotation unconditionally promised
    an ``A.Compose``, which contradicted the actual return value; the
    annotation is now Optional (kept as a string so it does not require
    albumentations at import time).
    """
    return None
# --- Inference Class ---
class UrbanInference:
    """Runs binary classification inference and writes predictions to CSV.

    Builds a TerraMind v1 (base) backbone + UperNet decoder classification
    model through terratorch's ``ClassificationTask``, restores weights from
    a checkpoint, and pairs sequential predictions with sample directory
    names before saving them to ``<output_dir>/inference_predictions.csv``.
    """

    def __init__(self, args: argparse.Namespace):
        """Select the device, create the output directory, and build the model.

        Args:
            args: Parsed CLI namespace; must provide ``output_dir`` and
                ``checkpoint`` attributes.
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initializing Inference on device: {self.device}")
        self.model = self._init_model()
        self._load_checkpoint(args.checkpoint)

    def _init_model(self) -> nn.Module:
        """Construct the TerraMind classification model with random weights.

        Weights are not downloaded here (``backbone_pretrained=False``);
        they are restored afterwards by ``_load_checkpoint``.
        """
        model_args = dict(
            backbone="terramind_v1_base",
            backbone_pretrained=False,
            backbone_modalities=["S2L2A"],  # Sentinel-2 L2A is the only input modality
            backbone_merge_method="mean",
            decoder="UperNetDecoder",
            decoder_scale_modules=True,
            decoder_channels=256,
            num_classes=2,  # binary output — presumably methane presence; confirm with training config
            head_dropout=0.3,
            necks=[
                {"name": "ReshapeTokensToImage", "remove_cls_token": False},
                # Tap four transformer depths to feed UperNet multi-scale inputs.
                {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
            ],
        )
        task = ClassificationTask(
            model_args=model_args,
            model_factory="EncoderDecoderFactory",
            loss="ce",
            ignore_index=-1
        )
        task.configure_models()
        return task.model.to(self.device)

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from disk and switch the model to eval mode.

        Accepts either a raw state_dict or a Lightning-style checkpoint that
        nests weights under 'state_dict'. Loading is non-strict, so key
        mismatches are silently tolerated — NOTE(review): consider logging
        the missing/unexpected keys returned by ``load_state_dict``.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        path = Path(checkpoint_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found at {path}")
        logger.info(f"Loading weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

    def run_inference(self, dataloader: DataLoader, sample_names: List[str]):
        """
        Generates binary predictions and matches them with provided sample_names (folder names).

        Pairing is purely positional: the dataloader must yield samples in
        the same order as ``sample_names`` (i.e. no shuffling).
        """
        sample_results: Dict[str, int] = {}
        logger.info(f"Starting inference on {len(sample_names)} samples...")
        # Iterator for sample names to match sequential predictions
        name_iter = iter(sample_names)
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Predicting"):
                # Each batch is assumed to be a dict carrying an 'S2L2A'
                # tensor, produced by MethaneUrbanDataModule — TODO confirm.
                inputs = batch['S2L2A'].to(self.device)
                # Forward Pass
                outputs = self.model(x={"S2L2A": inputs})
                # terratorch's model output object exposes logits via
                # ``.output`` (assumed; verify against terratorch version).
                probabilities = torch.softmax(outputs.output, dim=1)
                # Get binary prediction (0 or 1)
                predictions = torch.argmax(probabilities, dim=1)
                batch_preds = predictions.cpu().numpy()
                # Assign Directory Names to Predictions
                for pred in batch_preds:
                    try:
                        dir_name = next(name_iter)
                        sample_results[dir_name] = int(pred)
                    except StopIteration:
                        logger.error("More predictions generated than sample names provided! Check dataloader sync.")
                        break
        # Check if we missed any samples
        if len(sample_results) != len(sample_names):
            logger.warning(f"Mismatch: Expected {len(sample_names)} results, got {len(sample_results)}.")
        # Save CSV
        self._save_results(sample_results)

    def _save_results(self, results: Dict[str, int]):
        """Write the {sample directory: prediction} mapping to a two-column CSV."""
        csv_path = self.output_dir / "inference_predictions.csv"
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Sample_Directory', 'Prediction'])
            for sample, pred in results.items():
                writer.writerow([sample, pred])
        logger.info(f"Predictions saved to {csv_path}")
# --- Data Loading ---
def get_dataloader_and_names(args) -> Tuple[DataLoader, List[str]]:
    """Build the test dataloader and the ordered list of sample directory names.

    Sample names come either from an Excel manifest ('Filename' column,
    truncated at the first underscore) or, failing that, from the sorted
    subdirectories of ``args.root_dir`` — sorting keeps them aligned with
    the sequential dataloader.

    Args:
        args: Parsed CLI namespace with ``root_dir``, ``excel_file``, and
            ``batch_size``.

    Returns:
        A ``(dataloader, sample_names)`` pair with matching order.

    Raises:
        FileNotFoundError: If ``args.root_dir`` does not exist.
    """
    root_path = Path(args.root_dir)
    if not root_path.exists():
        raise FileNotFoundError(f"Data directory {args.root_dir} not found.")

    sample_names: Optional[List[str]] = None
    if args.excel_file:
        try:
            manifest = pd.read_excel(args.excel_file)
            # Keep only the directory-name prefix of each 'Filename' entry.
            sample_names = [str(entry).split('_')[0] for entry in manifest['Filename']]
            logger.info(f"Filtered {len(sample_names)} samples from Excel.")
        except Exception as exc:
            logger.error(f"Error reading Excel: {exc}")
            raise

    if sample_names is None:
        # Fallback: every subdirectory, sorted to match the sequential
        # dataloader order.
        sample_names = sorted(child.name for child in root_path.iterdir() if child.is_dir())
        logger.info(f"Found {len(sample_names)} samples in directory (Sorted).")

    # Initialize the datamodule for the test stage.
    datamodule = MethaneUrbanDataModule(
        data_root=args.root_dir,
        excel_file=None,
        batch_size=args.batch_size,
        paths=sample_names,
        train_transform=None,
        val_transform=get_inference_transforms(),
        test_transform=get_inference_transforms()
    )
    datamodule.paths = sample_names
    datamodule.setup(stage="test")

    # Prefer the test loader when the datamodule provides one.
    if hasattr(datamodule, 'test_dataloader'):
        loader = datamodule.test_dataloader()
    else:
        loader = datamodule.train_dataloader()
    return loader, sample_names
# --- Main Execution ---
def parse_args():
    """Define and parse the command-line interface for this inference script."""
    parser = argparse.ArgumentParser(description="Methane Urban Inference (Directory Names)")
    # Mandatory inputs: where the samples live and which weights to load.
    required_flags = (
        ('--root_dir', 'Root directory containing sample folders'),
        ('--checkpoint', 'Path to model checkpoint (.pth)'),
    )
    for flag, help_text in required_flags:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    # Optional knobs with sensible defaults.
    parser.add_argument('--excel_file', type=str, help='Optional Excel file to filter specific samples')
    parser.add_argument('--output_dir', type=str, default='./inference_results', help='Directory to save results')
    parser.add_argument('--batch_size', type=int, default=1, help='Inference batch size')
    parser.add_argument('--seed', type=int, default=42)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: seed RNGs, build the dataloader, run inference.
    cli_args = parse_args()
    set_seed(cli_args.seed)
    # 1. Prepare data and capture directory names in dataloader order.
    loader, names = get_dataloader_and_names(cli_args)
    # 2. Build the model, load weights, and write predictions to CSV.
    UrbanInference(cli_args).run_inference(loader, names)