Spaces:
Running
Running
Fix detection precision: gated fusion, strict registration, calibrated thresholds, preload model
ba4abf7 | """ | |
| AdaptFormer inference for satellite change detection. | |
| Downloads a pre-trained AdaptFormer model from HuggingFace Hub and runs | |
| tile-based inference on arbitrary-size image pairs, producing a binary | |
| change mask compatible with the rest of the detection pipeline. | |
| Falls back gracefully when torch/transformers are not installed. | |
| """ | |
| import logging | |
| import os | |
| import cv2 | |
| import numpy as np | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# --- Static configuration ---------------------------------------------------
# HuggingFace Hub repository the checkpoint is loaded from.
_MODEL_ID = "deepang/adaptformer-LEVIR-CD"
# Edge length, in pixels, of the square tiles fed to the model.
_TILE_SIZE = 256  # LEVIR-CD native patch size

# --- Lazily populated state (written by _load_model) ------------------------
_MODEL = None
_PROCESSOR = None
_DEVICE = None
# Tri-state: None = no verdict yet, True = model loaded, False = load failed.
_AVAILABLE = None
# Set once a load attempt has raised; later attempts then fail fast.
_LOAD_FAILED = False
def _try_import():
    """Import the optional deep-learning dependencies on demand.

    Returns:
        ``(torch, AutoImageProcessor, AutoModel)`` when both ``torch`` and
        ``transformers`` import cleanly, otherwise ``(None, None, None)``.
    """
    try:
        import torch
        from transformers import AutoImageProcessor, AutoModel
    except ImportError:
        # Optional dependency missing: report it as an all-None triple so
        # callers can test a single element.
        return (None,) * 3
    else:
        return torch, AutoImageProcessor, AutoModel
def _load_model():
    """Return the cached ``(model, processor)`` pair, loading it on first use.

    The processor and model are built in local variables and only published
    to the module-level cache once every step (download, device transfer,
    ``eval()``) has succeeded. A failure part-way through therefore never
    leaves a half-initialised model behind for the cache fast-path to return.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        RuntimeError: if torch/transformers cannot be imported, or if an
            earlier load attempt already failed.
        Exception: whatever the first failing load attempt raised; it is
            logged and re-raised unchanged.
    """
    global _MODEL, _PROCESSOR, _DEVICE, _AVAILABLE, _LOAD_FAILED

    if _MODEL is not None:
        return _MODEL, _PROCESSOR
    if _LOAD_FAILED:
        raise RuntimeError("AdaptFormer load previously failed")

    torch, AutoImageProcessor, AutoModel = _try_import()
    if torch is None:
        raise RuntimeError("PyTorch/transformers not installed")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cache_dir = os.environ.get("HF_HOME")
    logger.info("Loading AdaptFormer from %s ...", _MODEL_ID)
    try:
        # NOTE(security): trust_remote_code=True runs Python code shipped in
        # the Hub repository. Consider pinning a reviewed `revision`.
        processor = AutoImageProcessor.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model.to(device)
        model.eval()
    except Exception as exc:
        # Remember the failure so later calls fail fast instead of retrying.
        _LOAD_FAILED = True
        _AVAILABLE = False
        logger.error("AdaptFormer load failed: %s", exc)
        raise

    # Publish only after everything above succeeded.
    _MODEL, _PROCESSOR, _DEVICE = model, processor, device
    _AVAILABLE = True
    logger.info("AdaptFormer loaded on %s", _DEVICE)
    return _MODEL, _PROCESSOR
def is_model_available():
    """Report whether AdaptFormer can be used for inference.

    Returns the cached verdict when a load has already been attempted;
    otherwise triggers a load and reports whether it succeeded. Never raises.
    """
    # A previous load attempt already settled the answer.
    if _AVAILABLE is not None:
        return _AVAILABLE
    if not _LOAD_FAILED:
        try:
            _load_model()
        except Exception:
            return False
        else:
            return True
    return False
def preload_model():
    """Best-effort warm-up: load AdaptFormer once, e.g. at app startup.

    Returns:
        True if the model loaded, False if loading raised. The failure is
        logged as a warning and swallowed.
    """
    loaded = False
    try:
        _load_model()
        logger.info("AdaptFormer preload complete")
        loaded = True
    except Exception as exc:
        logger.warning("AdaptFormer preload skipped: %s", exc)
    return loaded
def _tile_origins(length, tile, stride):
    """Return tile start offsets that cover ``[0, length)``; needs length >= tile."""
    last = length - tile
    origins = list(range(0, last + 1, stride))
    if origins[-1] != last:
        # Regular stepping stopped short of the far edge: add one
        # edge-aligned tile so the remaining strip is not left unscored.
        origins.append(last)
    return origins


def _blend_weights(tile, overlap):
    """Return a (tile, tile) float32 window, tapered at the edges, never zero."""
    # Drop the 0 and 1 endpoints so pixels covered by a single tile (the
    # outer image border) still get a positive weight.
    ramp = np.linspace(0.0, 1.0, overlap + 2)[1:-1]
    flat = np.ones(tile - 2 * overlap)
    profile = np.concatenate([ramp, flat, ramp[::-1]])
    return np.outer(profile, profile).astype(np.float32)


def predict_change_mask(img1, img2, threshold=0.5):
    """
    Run AdaptFormer inference on two RGB numpy arrays (H, W, 3).

    The pair is processed in overlapping square tiles of ``_TILE_SIZE``
    pixels (overlap = tile // 4). Per-tile "change" probabilities are blended
    with a tapered weight window and normalised by the accumulated weight.
    Every pixel is covered by at least one tile with non-zero weight.

    Args:
        img1: First image, array of shape (H, W, 3).
            NOTE(review): presumably uint8, since tiles are passed to
            ``PIL.Image.fromarray`` - confirm against callers.
        img2: Second image; resized to ``img1``'s size when shapes differ.
        threshold: Score at or above which a pixel is marked as changed.
            Use threshold > 1.0 to obtain score map only (empty mask).

    Returns:
        (uint8 mask [0 or 255], float32 score map [0-1]), both of shape (H, W).

    Raises:
        RuntimeError: propagated from ``_load_model`` when the model or its
            dependencies are unavailable.
    """
    torch, _, _ = _try_import()
    model, processor = _load_model()
    from PIL import Image as PILImage

    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

    h, w = img1.shape[:2]
    if h == 0 or w == 0:
        # Nothing to tile; reflect-padding an empty axis would raise.
        empty_score = np.zeros((h, w), dtype=np.float32)
        return empty_score.astype(np.uint8), empty_score

    tile = _TILE_SIZE
    overlap = tile // 4
    stride = tile - overlap

    # Pad only when the image is smaller than one tile; larger images are
    # covered by the edge-aligned final tile from _tile_origins.
    pad_h = max(0, tile - h)
    pad_w = max(0, tile - w)
    if pad_h or pad_w:
        pad_spec = ((0, pad_h), (0, pad_w), (0, 0))
        img1 = np.pad(img1, pad_spec, mode="reflect")
        img2 = np.pad(img2, pad_spec, mode="reflect")
    ph, pw = img1.shape[:2]

    score_sum = np.zeros((ph, pw), dtype=np.float32)
    weight_sum = np.zeros((ph, pw), dtype=np.float32)
    weight_2d = _blend_weights(tile, overlap)
    row_origins = _tile_origins(ph, tile, stride)
    col_origins = _tile_origins(pw, tile, stride)

    with torch.no_grad():
        for y0 in row_origins:
            for x0 in col_origins:
                window = (slice(y0, y0 + tile), slice(x0, x0 + tile))
                pil1 = PILImage.fromarray(img1[window])
                pil2 = PILImage.fromarray(img2[window])
                inputs = processor(images=(pil1, pil2), return_tensors="pt")
                inputs = {k: v.to(_DEVICE) for k, v in inputs.items()}
                logits = model(**inputs).logits
                # assumes channel 1 of the logits is the "change" class
                # - TODO confirm against the checkpoint
                probs = torch.softmax(logits, dim=1)
                prob_map = probs[0, 1].cpu().numpy()
                if prob_map.shape != (tile, tile):
                    # Model output resolution differs from the tile size.
                    prob_map = cv2.resize(prob_map, (tile, tile),
                                          interpolation=cv2.INTER_LINEAR)
                score_sum[window] += prob_map * weight_2d
                weight_sum[window] += weight_2d

    # Weights are strictly positive wherever a tile landed; the floor is a
    # defensive guard against division by zero.
    avg_score = (score_sum / np.maximum(weight_sum, 1e-6))[:h, :w]
    mask = (avg_score >= threshold).astype(np.uint8) * 255
    return mask, avg_score