File size: 4,954 Bytes
1b91ffc
3808a54
1b91ffc
3808a54
1b91ffc
 
 
3808a54
1b91ffc
 
 
 
 
 
 
 
 
 
3808a54
 
 
ce994d8
3808a54
ba4abf7
1b91ffc
 
3808a54
1b91ffc
 
3808a54
 
1b91ffc
3808a54
1b91ffc
 
 
ba4abf7
1b91ffc
3808a54
ba4abf7
 
3808a54
 
1b91ffc
3808a54
1b91ffc
3808a54
1b91ffc
3808a54
 
ba4abf7
 
 
 
 
 
 
 
 
 
 
 
 
 
3808a54
1b91ffc
 
ba4abf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b91ffc
 
3808a54
 
ba4abf7
1b91ffc
3808a54
 
 
1b91ffc
 
 
 
 
 
ba4abf7
 
1b91ffc
 
 
 
 
 
 
 
 
 
 
ce994d8
 
 
 
 
1b91ffc
 
 
3808a54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce994d8
 
1b91ffc
ce994d8
1b91ffc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
AdaptFormer inference for satellite change detection.

Downloads a pre-trained AdaptFormer model from HuggingFace Hub and runs
tile-based inference on arbitrary-size image pairs, producing a binary
change mask compatible with the rest of the detection pipeline.

Falls back gracefully when torch/transformers are not installed.
"""
import logging
import os

import cv2
import numpy as np

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lazily-populated module state; _load_model() assigns these on first use so
# that importing this module stays cheap and does not require torch.
_MODEL = None
_PROCESSOR = None
_DEVICE = None
# HuggingFace Hub repository id passed to from_pretrained().
_MODEL_ID = "deepang/adaptformer-LEVIR-CD"
_TILE_SIZE = 256  # LEVIR-CD native patch size
# Tri-state availability cache read by is_model_available():
# None = not determined yet, True/False = outcome recorded by _load_model().
_AVAILABLE = None
# Set once a load attempt has raised, so later calls fail fast instead of
# retrying the load.
_LOAD_FAILED = False


def _try_import():
    try:
        import torch
        from transformers import AutoImageProcessor, AutoModel
        return torch, AutoImageProcessor, AutoModel
    except ImportError:
        return None, None, None


def _load_model():
    """Load (once) and return the AdaptFormer model and its image processor.

    The first successful call downloads/loads the weights, moves the model to
    CUDA when available (CPU otherwise), switches it to eval mode and caches
    everything in module globals. Later calls return the cached objects.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        RuntimeError: if torch/transformers cannot be imported, or if a
            previous load attempt already failed.
        Exception: whatever the underlying load raised on the first failing
            attempt (re-raised after being logged).
    """
    global _MODEL, _PROCESSOR, _DEVICE, _AVAILABLE, _LOAD_FAILED
    if _MODEL is not None:
        return _MODEL, _PROCESSOR
    if _LOAD_FAILED:
        raise RuntimeError("AdaptFormer load previously failed")

    torch, AutoImageProcessor, AutoModel = _try_import()
    if torch is None:
        # Record the outcome so availability checks do not re-attempt the
        # import on every call.
        _AVAILABLE = False
        raise RuntimeError("PyTorch/transformers not installed")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): HF_HOME is normally the root of the HuggingFace cache tree
    # (the hub cache lives in a sub-directory); confirm that using it directly
    # as cache_dir is intended.
    cache_dir = os.environ.get("HF_HOME")
    logger.info("Loading AdaptFormer from %s ...", _MODEL_ID)
    try:
        # SECURITY: trust_remote_code=True executes Python code shipped in the
        # Hub repository. Only acceptable for a trusted repo; consider pinning
        # a specific revision.
        processor = AutoImageProcessor.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            _MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
        model.to(device)
        model.eval()
    except Exception as exc:
        _LOAD_FAILED = True
        _AVAILABLE = False
        logger.error("AdaptFormer load failed: %s", exc)
        raise

    # Publish the globals only after every step succeeded, so a failure part
    # way through (e.g. in .to(device)) can never leave a half-initialised
    # model behind for the `_MODEL is not None` fast path to return.
    _MODEL, _PROCESSOR, _DEVICE = model, processor, device
    _AVAILABLE = True
    logger.info("AdaptFormer loaded on %s", device)
    return _MODEL, _PROCESSOR


def is_model_available():
    """Report whether AdaptFormer can be used.

    Returns the cached availability flag when one has been recorded, returns
    False when a previous load attempt failed, and otherwise tries to load the
    model and reports whether that succeeded. Never raises.
    """
    cached = _AVAILABLE
    if cached is not None:
        return cached

    loaded = False
    if not _LOAD_FAILED:
        try:
            _load_model()
            loaded = True
        except Exception:
            # Any load problem simply means "not available".
            loaded = False
    return loaded


def preload_model():
    """Best-effort warm-up of AdaptFormer, intended for app startup.

    Returns:
        True when the model loaded, False when loading raised (the error is
        logged as a warning and swallowed).
    """
    try:
        _load_model()
    except Exception as exc:
        logger.warning("AdaptFormer preload skipped: %s", exc)
        return False

    logger.info("AdaptFormer preload complete")
    return True


def _tile_origins(length, tile, stride):
    """Return tile start offsets along one axis so tiles cover [0, length)."""
    last = length - tile
    if last < 0:
        return []
    origins = list(range(0, last + 1, stride))
    # When stride does not divide (length - tile) the regular grid stops short
    # of the far edge; add one final tile flush with that edge.
    if origins[-1] != last:
        origins.append(last)
    return origins


def _blend_weights(tile, overlap):
    """Build a (tile, tile) float32 blending window that is > 0 everywhere."""
    if overlap <= 0:
        return np.ones((tile, tile), dtype=np.float32)
    # Start the ramp above zero: a ramp beginning at exactly 0 gives the
    # outermost image row/column (covered by a single tile) zero total weight.
    ramp = np.linspace(1.0 / overlap, 1.0, overlap)
    flat = np.ones(tile - 2 * overlap)
    profile = np.concatenate([ramp, flat, ramp[::-1]])
    return np.outer(profile, profile).astype(np.float32)


def predict_change_mask(img1, img2, threshold=0.5):
    """
    Run AdaptFormer inference on two RGB numpy arrays (H, W, 3).

    The images are reflect-padded to a multiple of the tile size, processed as
    overlapping tiles, and the per-tile change probabilities are blended with
    a tapered weight window before being cropped back to the input size.

    Args:
        img1: first image, array of shape (H, W, 3).
        img2: second image; resized to img1's width/height when the shapes
            differ.
        threshold: score at or above which a pixel is marked as changed.
            Use threshold > 1.0 to obtain score map only (empty mask).

    Returns:
        (uint8 mask [0 or 255], float32 score map [0-1]), both of shape (H, W).

    Raises:
        RuntimeError: if torch/transformers are missing or the model load
            previously failed; load errors from the first attempt propagate.
    """
    torch, _, _ = _try_import()
    model, processor = _load_model()
    from PIL import Image as PILImage

    if img1.shape != img2.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

    h, w = img1.shape[:2]
    tile = _TILE_SIZE
    overlap = tile // 4
    stride = tile - overlap

    # Pad up to a multiple of the tile size so at least one full tile fits.
    pad_h = (tile - h % tile) % tile
    pad_w = (tile - w % tile) % tile
    if pad_h or pad_w:
        img1 = np.pad(img1, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
        img2 = np.pad(img2, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

    ph, pw = img1.shape[:2]
    score_sum = np.zeros((ph, pw), dtype=np.float32)
    count = np.zeros((ph, pw), dtype=np.float32)

    weight_2d = _blend_weights(tile, overlap)
    row_origins = _tile_origins(ph, tile, stride)
    col_origins = _tile_origins(pw, tile, stride)

    with torch.no_grad():
        for y0 in row_origins:
            for x0 in col_origins:
                t1 = img1[y0:y0 + tile, x0:x0 + tile]
                t2 = img2[y0:y0 + tile, x0:x0 + tile]

                pil1 = PILImage.fromarray(t1)
                pil2 = PILImage.fromarray(t2)

                inputs = processor(images=(pil1, pil2), return_tensors="pt")
                inputs = {k: v.to(_DEVICE) for k, v in inputs.items()}

                outputs = model(**inputs)
                logits = outputs.logits
                probs = torch.softmax(logits, dim=1)

                # Channel 1 of the first batch item is used as the change
                # score - presumably the model's "change" class; confirm
                # against the model card.
                prob_map = probs[0, 1].cpu().numpy()

                # The model may emit a different resolution than the tile.
                out_h, out_w = prob_map.shape
                if out_h != tile or out_w != tile:
                    prob_map = cv2.resize(prob_map, (tile, tile),
                                          interpolation=cv2.INTER_LINEAR)

                score_sum[y0:y0 + tile, x0:x0 + tile] += prob_map * weight_2d
                count[y0:y0 + tile, x0:x0 + tile] += weight_2d

    # Weighted average; the floor only guards against division by zero.
    count = np.maximum(count, 1e-6)
    avg_score = score_sum / count
    avg_score = avg_score[:h, :w]

    mask = (avg_score >= threshold).astype(np.uint8) * 255
    return mask, avg_score