import os

import gradio as gr
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm.models.vision_transformer import VisionTransformer
from torchvision import transforms


# ==========================
# Model Components
# ==========================
class FeatureDifferenceModule(nn.Module):
    """Absolute feature difference followed by a channel-halving Conv-BN-ReLU."""

    def __init__(self, in_channels):
        super(FeatureDifferenceModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(in_channels // 2)
        self.relu = nn.ReLU()

    def forward(self, feat1, feat2):
        """Return ReLU(BN(Conv(|feat1 - feat2|))); output has in_channels // 2 channels."""
        x = torch.abs(feat1 - feat2)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class DeconvDecoder(nn.Module):
    """Three stride-2 transposed convolutions: upsamples the input 8x spatially."""

    def __init__(self, in_channels, num_classes):
        super(DeconvDecoder, self).__init__()
        # First deconv takes in_channels // 2 (the FeatureDifferenceModule output).
        self.deconv1 = nn.ConvTranspose2d(in_channels // 2, 128, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 32, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, num_classes, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)

    def forward(self, x):
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = self.deconv3(x)  # raw logits; sigmoid is applied by the caller
        return x


class ChangeFormer(nn.Module):
    """Bi-temporal change detector: shared ViT encoder -> feature diff -> deconv decoder."""

    def __init__(self, img_size=256, num_classes=1):
        super(ChangeFormer, self).__init__()
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=16,
            embed_dim=384,
            depth=4,
            num_heads=6,
            in_chans=4,  # 4 input bands: RGB + NIR
        )
        self.feature_diff = FeatureDifferenceModule(in_channels=384)
        self.decoder = DeconvDecoder(in_channels=384, num_classes=num_classes)
        self.img_size = img_size
        self.patch_size = 16

    def forward(self, img1, img2):
        """Return per-pixel change logits, shape (B, num_classes, img_size, img_size)."""
        feat1 = self.encoder.forward_features(img1)
        feat2 = self.encoder.forward_features(img2)
        # Drop the class token; keep only the patch tokens.
        feat1 = feat1[:, 1:, :]
        feat2 = feat2[:, 1:, :]
        B, N, C = feat1.shape
        h = w = self.img_size // self.patch_size
        # (B, N, C) -> (B, C, h, w): lay the patch tokens back out on a spatial grid.
        feat1 = feat1.transpose(1, 2).view(B, C, h, w)
        feat2 = feat2.transpose(1, 2).view(B, C, h, w)
        diff = self.feature_diff(feat1, feat2)
        out = self.decoder(diff)
        # Decoder only upsamples 8x (16 -> 128); interpolate the rest of the way.
        out = F.interpolate(out, size=(self.img_size, self.img_size),
                            mode='bilinear', align_corners=False)
        return out


# ==========================
# Model Initialization
# ==========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChangeFormer(num_classes=1).to(device)
print("ChangeFormer Model Initialized!")

# Load model weights
model_path = "/content/drive/MyDrive/DeforestationApp/models/best_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}.")
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

PATCH_SIZE = 256
transform = transforms.ToTensor()


def read_patch_4band(path, x, y, size=PATCH_SIZE):
    """Read a size x size window of the first (up to) 4 bands as an HWC array.

    If the raster carries 8+ bands, band 8 is treated as a Sentinel-2 SCL layer
    and pixels classed as cloud shadow (3) or cloud (8, 9) are zeroed out.
    """
    with rasterio.open(path) as src:
        band_indices = [i for i in range(1, min(src.count, 4) + 1)]  # Bands 1–4
        patch = src.read(band_indices, window=rasterio.windows.Window(x, y, size, size))
        # Optional: cloud masking if band 8 (SCL) is present
        if src.count >= 8:
            scl = src.read(8, window=rasterio.windows.Window(x, y, size, size))
            cloud_mask = (scl == 3) | (scl == 8) | (scl == 9)
            patch[:, cloud_mask] = 0
    patch = np.transpose(patch, (1, 2, 0))  # CHW -> HWC for transforms.ToTensor
    return patch


def get_patch_coords(path, patch_size=PATCH_SIZE):
    """Return top-left coords of every full patch, plus the raster (width, height).

    Partial patches at the right/bottom edges are intentionally skipped, so the
    stitched mask stays zero there.
    """
    with rasterio.open(path) as src:
        w, h = src.width, src.height
    coords = [(x, y)
              for y in range(0, h, patch_size)
              for x in range(0, w, patch_size)
              if x + patch_size <= w and y + patch_size <= h]
    return coords, (w, h)


def predict_on_large_4band_tifs(path1, path2):
    """Run the model patch-wise over two rasters; return ([(mask, (x, y)), ...], full_size).

    Each mask is a (PATCH_SIZE, PATCH_SIZE) uint8 binary array (threshold 0.5).
    NOTE(review): coords and full_size come from path1 only — this assumes
    path2 is co-registered with identical dimensions; confirm upstream.
    """
    coords, full_size = get_patch_coords(path1)
    preds = []
    for i in range(0, len(coords), 4):  # Batch size of 4
        batch_coords = coords[i:i + 4]
        batch_t1, batch_t2 = [], []
        for x, y in batch_coords:
            patch1 = read_patch_4band(path1, x, y)
            patch2 = read_patch_4band(path2, x, y)
            batch_t1.append(transform(patch1))
            batch_t2.append(transform(patch2))
        t1 = torch.stack(batch_t1).to(device)
        t2 = torch.stack(batch_t2).to(device)
        with torch.no_grad():
            pred = model(t1, t2)
            # BUGFIX: squeeze only the channel dim. A bare .squeeze() also drops
            # the batch dim when the final batch holds a single patch, making the
            # loop below iterate over image ROWS instead of per-patch masks.
            pred = torch.sigmoid(pred).squeeze(1).cpu().numpy()
        for p, (x, y) in zip(pred, batch_coords):
            pred_binary = (p > 0.5).astype(np.uint8)
            preds.append((pred_binary, (x, y)))
    return preds, full_size


def stitch_patches(preds, full_size, patch_size=PATCH_SIZE):
    """Assemble per-patch binary masks into a single (height, width) uint8 mosaic."""
    stitched = np.zeros((full_size[1], full_size[0]), dtype=np.uint8)
    for patch, (x, y) in preds:
        stitched[y:y + patch_size, x:x + patch_size] = patch
    return stitched


def normalize_rgb(path):
    """Read bands 1-3 and stretch them to [0, 1] with a 2–98 percentile clip."""
    with rasterio.open(path) as src:
        rgb = src.read([1, 2, 3]).astype(np.float32)
    rgb = np.transpose(rgb, (1, 2, 0))
    # Exclude nodata (all-zero) and NaN pixels from the percentile computation.
    mask = np.any(np.isnan(rgb), axis=-1) | np.all(rgb == 0, axis=-1)
    rgb[mask] = np.nan
    p2 = np.nanpercentile(rgb, 2)
    p98 = np.nanpercentile(rgb, 98)
    if p98 - p2 < 1e-5:
        # Degenerate stretch (nearly flat image): fall back to an 8-bit scale.
        rgb = np.clip(rgb / 255.0, 0, 1)
    else:
        rgb = np.clip((rgb - p2) / (p98 - p2), 0, 1)
    rgb = np.nan_to_num(rgb)  # masked pixels become 0 again
    return rgb


def overlay_mask(rgb_img, mask, alpha=0.4):
    """Alpha-blend the change mask (red channel) over a [0,1] RGB image; returns uint8."""
    mask = mask.astype(np.float32)
    color_mask = np.zeros_like(rgb_img)
    color_mask[..., 0] = mask  # change shows up in red
    blended = (1 - alpha) * rgb_img + alpha * color_mask
    blended = np.clip(blended, 0, 1)
    return (blended * 255).astype(np.uint8)


def generate_comment(mask):
    """Summarize the changed-pixel percentage as a human-readable sentence."""
    changed_pixels = np.count_nonzero(mask)
    total_pixels = mask.size
    percent = (changed_pixels / total_pixels) * 100
    if percent > 5:
        return f"Significant change detected: {percent:.2f}%"
    elif percent > 1:
        return f"Minor change detected: {percent:.2f}%"
    elif percent > 0:
        return f"Minimal change: {percent:.2f}%"
    else:
        return "No change detected."


def clear_outputs():
    """Reset all four Gradio outputs when a new file is uploaded."""
    return None, None, None, "Please upload new images to generate results."
def predict_change(file1, file2):
    """Gradio handler: run change detection on two uploaded 4-band GeoTIFFs.

    Returns (raw RGB image, overlay image, binary mask image, comment) as PIL
    images plus a string, or (None, None, None, error message) on failure.
    """
    try:
        path1, path2 = file1.name, file2.name
        # BUGFIX: validate BOTH inputs, not just the first — a second image
        # with fewer than 4 bands previously failed deep inside the model
        # with a cryptic shape error instead of this clear message.
        for path in (path1, path2):
            with rasterio.open(path) as src:
                if src.count < 4:
                    raise ValueError("Input image must have at least 4 bands (RGB+NIR).")
        preds, full_size = predict_on_large_4band_tifs(path1, path2)
        mask = stitch_patches(preds, full_size)
        rgb = normalize_rgb(path2)  # display the NEWER image
        overlay = overlay_mask(rgb, mask)
        return (
            Image.fromarray((rgb * 255).astype(np.uint8)),
            Image.fromarray(overlay),
            Image.fromarray((mask * 255).astype(np.uint8)),
            generate_comment(mask)
        )
    except Exception as e:
        # Surface failures in the comment box rather than crashing the UI.
        return None, None, None, f"Error: {str(e)}"


# ==========================
# Gradio UI
# ==========================
with gr.Blocks() as demo:
    gr.Markdown("### UPLOAD INSTRUCTIONS:\n- **First Image** → OLDER image (earlier date)\n- **Second Image** → NEWER image (later date)\n\n> Both images must have **at least 4 bands (RGB + NIR)**.")
    with gr.Row():
        file1 = gr.File(label=" First Image (OLDER)", file_types=[".tif"])
        file2 = gr.File(label=" Second Image (NEWER)", file_types=[".tif"])
    with gr.Row():
        output1 = gr.Image(label="Raw Second Image RGB")
        output2 = gr.Image(label="Overlay with Prediction")
        output3 = gr.Image(label="Binary Change Mask")
        output4 = gr.Textbox(label="Auto-generated Comment")
    # Clear stale results whenever either input file changes.
    file1.upload(clear_outputs, None, [output1, output2, output3, output4])
    file2.upload(clear_outputs, None, [output1, output2, output3, output4])
    btn = gr.Button("Submit")
    btn.click(predict_change, inputs=[file1, file2],
              outputs=[output1, output2, output3, output4])

demo.launch()