""" AdaptFormer inference for satellite change detection. Downloads a pre-trained AdaptFormer model from HuggingFace Hub and runs tile-based inference on arbitrary-size image pairs, producing a binary change mask compatible with the rest of the detection pipeline. Falls back gracefully when torch/transformers are not installed. """ import logging import os import cv2 import numpy as np logger = logging.getLogger(__name__) _MODEL = None _PROCESSOR = None _DEVICE = None _MODEL_ID = "deepang/adaptformer-LEVIR-CD" _TILE_SIZE = 256 # LEVIR-CD native patch size _AVAILABLE = None _LOAD_FAILED = False def _try_import(): try: import torch from transformers import AutoImageProcessor, AutoModel return torch, AutoImageProcessor, AutoModel except ImportError: return None, None, None def _load_model(): global _MODEL, _PROCESSOR, _DEVICE, _AVAILABLE, _LOAD_FAILED if _MODEL is not None: return _MODEL, _PROCESSOR if _LOAD_FAILED: raise RuntimeError("AdaptFormer load previously failed") torch, AutoImageProcessor, AutoModel = _try_import() if torch is None: raise RuntimeError("PyTorch/transformers not installed") _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") cache_dir = os.environ.get("HF_HOME", None) logger.info("Loading AdaptFormer from %s ...", _MODEL_ID) try: _PROCESSOR = AutoImageProcessor.from_pretrained( _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True) _MODEL = AutoModel.from_pretrained( _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True) _MODEL.to(_DEVICE) _MODEL.eval() _AVAILABLE = True logger.info("AdaptFormer loaded on %s", _DEVICE) except Exception as exc: _LOAD_FAILED = True _AVAILABLE = False logger.error("AdaptFormer load failed: %s", exc) raise return _MODEL, _PROCESSOR def is_model_available(): """True only if PyTorch is installed and the model loads successfully.""" global _AVAILABLE if _AVAILABLE is not None: return _AVAILABLE if _LOAD_FAILED: return False try: _load_model() return True except Exception: return False def preload_model(): """Warm-load AdaptFormer at app startup (best-effort).""" try: _load_model() logger.info("AdaptFormer preload complete") return True except Exception as exc: logger.warning("AdaptFormer preload skipped: %s", exc) return False def predict_change_mask(img1, img2, threshold=0.5): """ Run AdaptFormer inference on two RGB numpy arrays (H, W, 3). Returns (uint8 mask [0 or 255], float32 score map [0-1]). Use threshold > 1.0 to obtain score map only (empty mask). """ torch, _, _ = _try_import() model, processor = _load_model() from PIL import Image as PILImage if img1.shape != img2.shape: img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0])) h, w = img1.shape[:2] tile = _TILE_SIZE overlap = tile // 4 stride = tile - overlap pad_h = (tile - h % tile) % tile pad_w = (tile - w % tile) % tile if pad_h or pad_w: img1 = np.pad(img1, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") img2 = np.pad(img2, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") ph, pw = img1.shape[:2] score_sum = np.zeros((ph, pw), dtype=np.float32) count = np.zeros((ph, pw), dtype=np.float32) ramp = np.linspace(0, 1, overlap) flat = np.ones(tile - 2 * overlap) profile = np.concatenate([ramp, flat, ramp[::-1]]) weight_2d = np.outer(profile, profile).astype(np.float32) with torch.no_grad(): for y0 in range(0, ph - tile + 1, stride): for x0 in range(0, pw - tile + 1, stride): t1 = img1[y0:y0+tile, x0:x0+tile] t2 = img2[y0:y0+tile, x0:x0+tile] pil1 = PILImage.fromarray(t1) pil2 = PILImage.fromarray(t2) inputs = processor(images=(pil1, pil2), return_tensors="pt") inputs = {k: v.to(_DEVICE) for k, v in inputs.items()} outputs = model(**inputs) logits = outputs.logits probs = torch.softmax(logits, dim=1) prob_map = probs[0, 1].cpu().numpy() out_h, out_w = prob_map.shape if out_h != tile or out_w != tile: prob_map = cv2.resize(prob_map, (tile, tile), interpolation=cv2.INTER_LINEAR) score_sum[y0:y0+tile, x0:x0+tile] += prob_map * weight_2d count[y0:y0+tile, x0:x0+tile] += weight_2d count = np.maximum(count, 1e-6) avg_score = score_sum / count avg_score = avg_score[:h, :w] mask = (avg_score >= threshold).astype(np.uint8) * 255 return mask, avg_score