"""
AdaptFormer inference for satellite change detection.
Downloads a pre-trained AdaptFormer model from HuggingFace Hub and runs
tile-based inference on arbitrary-size image pairs, producing a binary
change mask compatible with the rest of the detection pipeline.
Falls back gracefully when torch/transformers are not installed.
"""
import logging
import os
import cv2
import numpy as np
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Lazily populated singletons: filled in by _load_model() and reused afterwards.
_MODEL = None
_PROCESSOR = None
_DEVICE = None

# HuggingFace Hub repository id passed to from_pretrained() in _load_model().
_MODEL_ID = "deepang/adaptformer-LEVIR-CD"
# Side length (pixels) of the square tiles fed to the model.
_TILE_SIZE = 256  # LEVIR-CD native patch size

# Tri-state load result: None until _load_model() records an outcome,
# then True (loaded) or False (load raised).
_AVAILABLE = None
# Set to True when a load attempt raised; later attempts fail fast instead of retrying.
_LOAD_FAILED = False
def _try_import():
    """Import the optional deep-learning dependencies on demand.

    Returns:
        A 3-tuple ``(torch, AutoImageProcessor, AutoModel)`` when both
        ``torch`` and ``transformers`` can be imported, otherwise
        ``(None, None, None)``.
    """
    try:
        import torch
        from transformers import AutoImageProcessor, AutoModel
    except ImportError:
        # Either package missing: signal "unavailable" with an all-None tuple.
        return (None,) * 3
    return torch, AutoImageProcessor, AutoModel
def _load_model():
    """Load the AdaptFormer model and its image processor once, then cache them.

    The first successful call downloads/loads ``_MODEL_ID`` via
    ``from_pretrained``, moves the model to CUDA when available (CPU
    otherwise), switches it to eval mode and stores the results in the
    module globals. Later calls return the cached objects.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        RuntimeError: if torch/transformers cannot be imported, or if an
            earlier load attempt already failed (no retry is made).
        Exception: whatever the underlying load raised; it is logged and
            re-raised after the failure is recorded.
    """
    global _MODEL, _PROCESSOR, _DEVICE, _AVAILABLE, _LOAD_FAILED

    if _MODEL is not None:
        return _MODEL, _PROCESSOR
    if _LOAD_FAILED:
        raise RuntimeError("AdaptFormer load previously failed")

    torch, AutoImageProcessor, AutoModel = _try_import()
    if torch is None:
        raise RuntimeError("PyTorch/transformers not installed")

    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cache_dir = os.environ.get("HF_HOME")
    logger.info("Loading AdaptFormer from %s ...", _MODEL_ID)

    try:
        # Build everything in locals first. Publishing _MODEL before .to()/.eval()
        # succeed would let a later call return a half-initialised model through
        # the `_MODEL is not None` fast path above, even though the load failed.
        # NOTE(security): trust_remote_code=True executes Python code shipped in
        # the Hub repository; consider pinning a reviewed `revision`.
        processor = AutoImageProcessor.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model.to(_DEVICE)
        model.eval()
    except Exception as exc:
        _LOAD_FAILED = True
        _AVAILABLE = False
        logger.error("AdaptFormer load failed: %s", exc)
        raise

    # Only a fully prepared model/processor pair becomes visible to callers.
    _PROCESSOR = processor
    _MODEL = model
    _AVAILABLE = True
    logger.info("AdaptFormer loaded on %s", _DEVICE)
    return _MODEL, _PROCESSOR
def is_model_available():
    """Report whether AdaptFormer can be used, loading it if necessary.

    Returns the recorded load outcome when one exists, ``False`` when a
    previous load attempt failed, and otherwise attempts a load now:
    ``True`` if it succeeds, ``False`` if it raises for any reason.
    """
    # A recorded outcome (True or False) short-circuits everything else.
    if _AVAILABLE is not None:
        return _AVAILABLE
    if _LOAD_FAILED:
        return False

    try:
        _load_model()
    except Exception:
        # Any failure (missing packages, download error, ...) means "not available".
        return False
    return True
def preload_model():
    """Best-effort warm-up of AdaptFormer, intended for application startup.

    Never raises: a failed load is logged as a warning.

    Returns:
        ``True`` if the model loaded, ``False`` otherwise.
    """
    loaded = False
    try:
        _load_model()
        logger.info("AdaptFormer preload complete")
        loaded = True
    except Exception as exc:
        logger.warning("AdaptFormer preload skipped: %s", exc)
    return loaded
def _tile_origins(length, tile, stride):
    """Return tile start offsets along one axis that cover [0, length) fully.

    Requires ``length >= tile``. Offsets advance by ``stride``; when the last
    regular tile stops short of the far edge, one extra edge-flush tile is added.
    """
    last = length - tile
    origins = list(range(0, last + 1, stride))
    if origins[-1] != last:
        origins.append(last)
    return origins


def _blend_window(tile, overlap):
    """Return a (tile, tile) float32 blending window that is strictly positive."""
    # The ramp starts at 1/overlap rather than 0: a zero edge weight would leave
    # image-border pixels (covered by a single tile) with no contribution at all.
    ramp = np.linspace(0.0, 1.0, overlap + 1)[1:]
    flat = np.ones(tile - 2 * overlap)
    profile = np.concatenate([ramp, flat, ramp[::-1]])
    return np.outer(profile, profile).astype(np.float32)


def predict_change_mask(img1, img2, threshold=0.5):
    """
    Run AdaptFormer inference on two RGB numpy arrays (H, W, 3).

    The pair is processed in overlapping ``_TILE_SIZE`` tiles (25% overlap)
    whose per-pixel change probabilities are blended with a tapered window.
    ``img2`` is resized to ``img1``'s size when their shapes differ, and
    images smaller than one tile are reflect-padded before inference.

    Returns (uint8 mask [0 or 255], float32 score map [0-1]), both (H, W).
    Use threshold > 1.0 to obtain score map only (empty mask).

    Raises RuntimeError (from the model loader) when torch/transformers are
    missing or the model could not be loaded.
    """
    torch, _, _ = _try_import()
    model, processor = _load_model()
    from PIL import Image as PILImage

    h, w = img1.shape[:2]
    if h == 0 or w == 0:
        # Nothing to infer on; return empty results of the matching shape.
        return (np.zeros((h, w), dtype=np.uint8),
                np.zeros((h, w), dtype=np.float32))

    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (w, h))

    tile = _TILE_SIZE
    overlap = tile // 4
    stride = tile - overlap

    # Pad only when the image is smaller than a single tile; larger images are
    # covered completely by the tile origins below, including an edge-flush tile.
    pad_h = max(tile - h, 0)
    pad_w = max(tile - w, 0)
    if pad_h or pad_w:
        pad_spec = ((0, pad_h), (0, pad_w), (0, 0))
        img1 = np.pad(img1, pad_spec, mode="reflect")
        img2 = np.pad(img2, pad_spec, mode="reflect")
    ph, pw = img1.shape[:2]

    score_sum = np.zeros((ph, pw), dtype=np.float32)
    count = np.zeros((ph, pw), dtype=np.float32)
    weight_2d = _blend_window(tile, overlap)

    y_origins = _tile_origins(ph, tile, stride)
    x_origins = _tile_origins(pw, tile, stride)

    with torch.no_grad():
        for y0 in y_origins:
            for x0 in x_origins:
                t1 = img1[y0:y0 + tile, x0:x0 + tile]
                t2 = img2[y0:y0 + tile, x0:x0 + tile]
                pil1 = PILImage.fromarray(t1)
                pil2 = PILImage.fromarray(t2)
                # NOTE(review): assumes the processor accepts the (before, after)
                # pair as `images` and the model exposes 2-class `.logits` with
                # channel 1 = "change" -- confirm against the model card.
                inputs = processor(images=(pil1, pil2), return_tensors="pt")
                inputs = {k: v.to(_DEVICE) for k, v in inputs.items()}
                outputs = model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1)
                prob_map = probs[0, 1].cpu().numpy()
                if prob_map.shape != (tile, tile):
                    # Model output resolution may differ from the input tile.
                    prob_map = cv2.resize(prob_map, (tile, tile),
                                          interpolation=cv2.INTER_LINEAR)
                score_sum[y0:y0 + tile, x0:x0 + tile] += prob_map * weight_2d
                count[y0:y0 + tile, x0:x0 + tile] += weight_2d

    # Every pixel lies in at least one tile and all window weights are > 0,
    # so `count` is strictly positive and the division is safe.
    avg_score = (score_sum / count)[:h, :w]
    mask = (avg_score >= threshold).astype(np.uint8) * 255
    return mask, avg_score
|