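# Deforestation change detection demo.
# A ChangeFormer model (shared ViT encoder, feature-difference module, deconvolutional decoder)
# compares two co-registered 4-band (RGB+NIR) GeoTIFFs tile by tile and serves the resulting
# change mask through a Gradio interface.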
import os
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import rasterio
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import VisionTransformer
# Model Components
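# Computes the absolute difference of two feature maps and compresses it to half the channels.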
class FeatureDifferenceModule(nn.Module):
    def __init__(self, in_channels):
        super(FeatureDifferenceModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(in_channels // 2)
        self.relu = nn.ReLU()

    def forward(self, feat1, feat2):
        x = torch.abs(feat1 - feat2)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
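# Upsamples the fused difference features with three stride-2 transposed convolutions
# (16x16 -> 128x128 for 256x256 inputs) and projects to the output classes.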
class DeconvDecoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(DeconvDecoder, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(in_channels // 2, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = self.deconv3(x)
        return x
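# Siamese change detector: both images pass through the same ViT encoder, the patch-token
# feature maps are differenced, decoded, and upsampled to per-pixel change logits.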
class ChangeFormer(nn.Module):
    def __init__(self, img_size=256, num_classes=1):
        super(ChangeFormer, self).__init__()
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=16,
            embed_dim=384,
            depth=4,
            num_heads=6,
            in_chans=4,
        )
        self.feature_diff = FeatureDifferenceModule(in_channels=384)
        self.decoder = DeconvDecoder(in_channels=384, num_classes=num_classes)
        self.img_size = img_size
        self.patch_size = 16

    def forward(self, img1, img2):
        # Shared encoder: embed both dates with the same ViT weights
        feat1 = self.encoder.forward_features(img1)
        feat2 = self.encoder.forward_features(img2)
        # Drop the class token, keep only the patch tokens
        feat1 = feat1[:, 1:, :]
        feat2 = feat2[:, 1:, :]
        # Reshape the token sequence back into a 2D feature map (B, C, h, w)
        B, N, C = feat1.shape
        h = w = self.img_size // self.patch_size
        feat1 = feat1.transpose(1, 2).view(B, C, h, w)
        feat2 = feat2.transpose(1, 2).view(B, C, h, w)
        diff = self.feature_diff(feat1, feat2)
        out = self.decoder(diff)
        # Upsample the decoder output to the full input resolution
        out = F.interpolate(out, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
        return out
# Model Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChangeFormer(num_classes=1).to(device)
print("ChangeFormer Model Initialized!")
# Load model weights
model_path = "/content/drive/MyDrive/DeforestationApp/models/best_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}.")
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
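# Inference settings: large rasters are processed as non-overlapping 256x256 tiles,
# and ToTensor converts HWC numpy patches to CHW tensors.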
PATCH_SIZE = 256
transform = transforms.ToTensor()
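# Read a 4-band (RGB+NIR) window; if an SCL band is present as band 8, zero out
# cloud-shadow and cloud pixels (SCL classes 3, 8, 9).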
def read_patch_4band(path, x, y, size=PATCH_SIZE):
    with rasterio.open(path) as src:
        band_indices = [i for i in range(1, min(src.count, 4) + 1)]  # Bands 1–4
        patch = src.read(band_indices, window=rasterio.windows.Window(x, y, size, size))
        # Optional: cloud masking if band 8 (SCL) is present
        if src.count >= 8:
            scl = src.read(8, window=rasterio.windows.Window(x, y, size, size))
            cloud_mask = (scl == 3) | (scl == 8) | (scl == 9)
            patch[:, cloud_mask] = 0
        patch = np.transpose(patch, (1, 2, 0))
        return patch
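# Enumerate the top-left corners of every full patch that fits inside the raster.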
def get_patch_coords(path, patch_size=PATCH_SIZE):
    with rasterio.open(path) as src:
        w, h = src.width, src.height
        coords = [(x, y) for y in range(0, h, patch_size)
                  for x in range(0, w, patch_size)
                  if x + patch_size <= w and y + patch_size <= h]
        return coords, (w, h)
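# Tiled inference: read matching windows from both rasters, run the model in mini-batches
# of 4 patches, and threshold the sigmoid output at 0.5.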
def predict_on_large_4band_tifs(path1, path2):
    coords, full_size = get_patch_coords(path1)
    preds = []
    for i in range(0, len(coords), 4):  # Batch size of 4
        batch_coords = coords[i:i+4]
        batch_t1, batch_t2 = [], []
        for x, y in batch_coords:
            patch1 = read_patch_4band(path1, x, y)
            patch2 = read_patch_4band(path2, x, y)
            batch_t1.append(transform(patch1))
            batch_t2.append(transform(patch2))
        t1 = torch.stack(batch_t1).to(device)
        t2 = torch.stack(batch_t2).to(device)
        with torch.no_grad():
            pred = model(t1, t2)
        # Squeeze only the channel dimension so a final batch of size 1 keeps its batch axis
        pred = torch.sigmoid(pred).squeeze(1).cpu().numpy()
        for p, (x, y) in zip(pred, batch_coords):
            pred_binary = (p > 0.5).astype(np.uint8)
            preds.append((pred_binary, (x, y)))
    return preds, full_size
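# Reassemble the per-tile masks into one full-size mask (full_size is (width, height)).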
def stitch_patches(preds, full_size, patch_size=PATCH_SIZE):
    stitched = np.zeros((full_size[1], full_size[0]), dtype=np.uint8)
    for patch, (x, y) in preds:
        stitched[y:y+patch_size, x:x+patch_size] = patch
    return stitched
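# Percentile-stretch the first three bands to [0, 1] for display, ignoring nodata/NaN pixels.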
def normalize_rgb(path):
    with rasterio.open(path) as src:
        rgb = src.read([1, 2, 3]).astype(np.float32)
    rgb = np.transpose(rgb, (1, 2, 0))
    mask = np.any(np.isnan(rgb), axis=-1) | np.all(rgb == 0, axis=-1)
    rgb[mask] = np.nan
    p2 = np.nanpercentile(rgb, 2)
    p98 = np.nanpercentile(rgb, 98)
    if p98 - p2 < 1e-5:
        rgb = np.clip(rgb / 255.0, 0, 1)
    else:
        rgb = np.clip((rgb - p2) / (p98 - p2), 0, 1)
    rgb = np.nan_to_num(rgb)
    return rgb
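# Blend the binary change mask into the first (red) channel of the display image.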
def overlay_mask(rgb_img, mask, alpha=0.4):
    mask = mask.astype(np.float32)
    color_mask = np.zeros_like(rgb_img)
    color_mask[..., 0] = mask
    blended = (1 - alpha) * rgb_img + alpha * color_mask
    blended = np.clip(blended, 0, 1)
    return (blended * 255).astype(np.uint8)
def generate_comment(mask):
    changed_pixels = np.count_nonzero(mask)
    total_pixels = mask.size
    percent = (changed_pixels / total_pixels) * 100
    if percent > 5:
        return f"Significant change detected: {percent:.2f}%"
    elif percent > 1:
        return f"Minor change detected: {percent:.2f}%"
    elif percent > 0:
        return f"Minimal change: {percent:.2f}%"
    else:
        return "No change detected."
def clear_outputs():
    return None, None, None, "Please upload new images to generate results."
def predict_change(file1, file2):
    try:
        path1, path2 = file1.name, file2.name
        # The model expects 4 channels from both dates, so validate both inputs
        for path in (path1, path2):
            with rasterio.open(path) as src:
                if src.count < 4:
                    raise ValueError("Input image must have at least 4 bands (RGB+NIR).")
        preds, full_size = predict_on_large_4band_tifs(path1, path2)
        mask = stitch_patches(preds, full_size)
        rgb = normalize_rgb(path2)
        overlay = overlay_mask(rgb, mask)
        return (
            Image.fromarray((rgb * 255).astype(np.uint8)),
            Image.fromarray(overlay),
            Image.fromarray((mask * 255).astype(np.uint8)),
            generate_comment(mask)
        )
    except Exception as e:
        return None, None, None, f"Error: {str(e)}"
# ==========================
# Gradio UI
# ==========================
with gr.Blocks() as demo:
    gr.Markdown("### UPLOAD INSTRUCTIONS:\n- **First Image** → OLDER image (earlier date)\n- **Second Image** → NEWER image (later date)\n\n> Both images must have **at least 4 bands (RGB + NIR)**.")
    with gr.Row():
        file1 = gr.File(label="First Image (OLDER)", file_types=[".tif"])
        file2 = gr.File(label="Second Image (NEWER)", file_types=[".tif"])
    with gr.Row():
        output1 = gr.Image(label="Raw Second Image RGB")
        output2 = gr.Image(label="Overlay with Prediction")
        output3 = gr.Image(label="Binary Change Mask")
    output4 = gr.Textbox(label="Auto-generated Comment")
    file1.upload(clear_outputs, None, [output1, output2, output3, output4])
    file2.upload(clear_outputs, None, [output1, output2, output3, output4])
    btn = gr.Button("Submit")
    btn.click(predict_change, inputs=[file1, file2], outputs=[output1, output2, output3, output4])

demo.launch()