Spaces:
Running
Running
File size: 3,999 Bytes
"""
AdaptFormer inference for satellite change detection.
Downloads a pre-trained AdaptFormer model from the Hugging Face Hub and runs
tile-based inference on arbitrary-size image pairs, producing a binary
change mask compatible with the rest of the detection pipeline.
torch/transformers are optional dependencies: callers can check
is_model_available() and fall back gracefully when they are not installed.
"""
import logging
import os
import cv2
import numpy as np
logger = logging.getLogger(__name__)

# Lazily-initialised inference state; populated by _load_model().
_MODEL = _PROCESSOR = _DEVICE = None

# Hugging Face Hub repository the processor and weights are loaded from.
_MODEL_ID = "deepang/adaptformer-LEVIR-CD"
# Edge length, in pixels, of the square windows fed to the model.
_TILE_SIZE = 512

# Memoised answer of is_model_available(); None means "not probed yet".
_AVAILABLE = None
def _try_import():
    """Import the optional deep-learning stack on demand.

    Returns:
        ``(torch, AutoImageProcessor, AutoModel)`` when both torch and
        transformers import cleanly, otherwise ``(None, None, None)``.
    """
    try:
        import torch
        from transformers import AutoImageProcessor, AutoModel
    except ImportError:
        return None, None, None
    return torch, AutoImageProcessor, AutoModel
def is_model_available():
    """Report whether torch and transformers can both be imported.

    The import probe runs at most once per process; its boolean result is
    memoised in the module-level ``_AVAILABLE`` flag and reused afterwards.
    """
    global _AVAILABLE
    if _AVAILABLE is None:
        _AVAILABLE = _try_import()[0] is not None
    return _AVAILABLE
def _load_model():
    """Load (once) and return the AdaptFormer model and its image processor.

    On the first successful call the processor and weights for ``_MODEL_ID``
    are fetched with ``from_pretrained``, the model is moved to CUDA when
    available (CPU otherwise) and switched to eval mode. The results are
    cached in the module-level ``_MODEL`` / ``_PROCESSOR`` / ``_DEVICE``
    globals and returned directly on later calls.

    Returns:
        ``(model, processor)``.

    Raises:
        RuntimeError: if torch or transformers cannot be imported.
    """
    global _MODEL, _PROCESSOR, _DEVICE
    if _MODEL is not None:
        return _MODEL, _PROCESSOR
    torch, AutoImageProcessor, AutoModel = _try_import()
    if torch is None:
        raise RuntimeError("PyTorch/transformers not installed")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Loading AdaptFormer from %s ...", _MODEL_ID)
    # No explicit cache_dir: transformers/huggingface_hub already honour
    # HF_HOME (caching under $HF_HOME/hub) and HF_HUB_CACHE. Passing HF_HOME
    # itself as cache_dir would bypass that standard cache layout.
    processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
    model = AutoModel.from_pretrained(_MODEL_ID)
    model.to(device)
    model.eval()
    # Publish to the cache only after every step succeeded, so a failure in
    # from_pretrained/.to()/.eval() cannot leave a half-initialised model
    # (e.g. still in train mode) cached for all later calls.
    _MODEL, _PROCESSOR, _DEVICE = model, processor, device
    logger.info("AdaptFormer loaded on %s", device)
    return model, processor
def _tile_starts(length, tile, stride):
    """Return window offsets along one axis so ``tile``-sized windows cover ``[0, length)``.

    Offsets advance by ``stride``; when that grid does not land exactly on
    ``length - tile``, one extra window flush with the far edge is appended.
    Assumes ``length >= tile``.
    """
    last = length - tile
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def predict_change_mask(img1, img2, threshold=0.5):
    """
    Run AdaptFormer inference on two RGB numpy arrays (H, W, 3).

    Images are split into overlapping ``_TILE_SIZE`` tiles (stride of 3/4 of a
    tile, with the last tile on each axis aligned to the image edge so every
    pixel is covered), predicted individually, and the per-pixel change
    probabilities are averaged where tiles overlap before thresholding.

    Args:
        img1: first image of the pair, an array ``PIL.Image.fromarray`` accepts.
        img2: second image; resized to img1's height/width if shapes differ.
        threshold: probability at or above which a pixel is marked as changed.

    Returns:
        (uint8 mask [0 or 255], float32 score map [0-1]), both of size (H, W).

    Raises:
        RuntimeError: if torch/transformers are not installed.
    """
    torch, _, _ = _try_import()
    model, processor = _load_model()
    from PIL import Image as PILImage

    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

    h, w = img1.shape[:2]
    tile = _TILE_SIZE
    stride = tile * 3 // 4

    # Tiles are edge-aligned, so reflect-padding is only needed when the
    # image is smaller than a single tile along some axis.
    pad_h = max(tile - h, 0)
    pad_w = max(tile - w, 0)
    if pad_h or pad_w:
        pad = ((0, pad_h), (0, pad_w), (0, 0))
        img1 = np.pad(img1, pad, mode="reflect")
        img2 = np.pad(img2, pad, mode="reflect")
    ph, pw = img1.shape[:2]

    score_sum = np.zeros((ph, pw), dtype=np.float32)
    count = np.zeros((ph, pw), dtype=np.float32)
    ys = _tile_starts(ph, tile, stride)
    xs = _tile_starts(pw, tile, stride)

    with torch.no_grad():
        for y0 in ys:
            for x0 in xs:
                t1 = img1[y0:y0 + tile, x0:x0 + tile]
                t2 = img2[y0:y0 + tile, x0:x0 + tile]
                pil1 = PILImage.fromarray(t1)
                pil2 = PILImage.fromarray(t2)
                inputs = processor(images=(pil1, pil2), return_tensors="pt")
                inputs = {k: v.to(_DEVICE) for k, v in inputs.items()}
                outputs = model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1)
                # Class 1 = change
                prob_map = probs[0, 1].cpu().numpy()
                out_h, out_w = prob_map.shape
                # The model may predict at a lower resolution than the tile.
                if out_h != tile or out_w != tile:
                    prob_map = cv2.resize(prob_map, (tile, tile),
                                          interpolation=cv2.INTER_LINEAR)
                score_sum[y0:y0 + tile, x0:x0 + tile] += prob_map
                count[y0:y0 + tile, x0:x0 + tile] += 1.0

    # Every pixel is covered by at least one tile; the clamp is only a guard
    # against division by zero.
    count = np.maximum(count, 1.0)
    avg_score = (score_sum / count)[:h, :w]
    mask = (avg_score >= threshold).astype(np.uint8) * 255
    return mask, avg_score
|