"""SpaceNet Building Detection.

Gradio demo that runs a UNet building-segmentation model on an image supplied
as an uploaded file, an http(s) image URL, or a lat/lon pair (fetched as an
ESRI World Imagery tile), and displays the input, a score heatmap and a
thresholded overlay.
"""
import math
import os
import tempfile
import urllib.parse
import urllib.request

import cv2
import gradio as gr
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image

from models.unet import UNet

MODEL_REPO = "harshinde/spacenet"
MODEL_FILENAME = "model.safetensors"
# Per-channel mean subtracted before scaling by 1/255 (see predict_tile).
CHANNEL_MEAN = np.array([71.2274, 78.3385, 56.2296], dtype=np.float32)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = None

# Network timeout (seconds) for every outbound HTTP request, so a slow or
# unresponsive server cannot hang a request indefinitely.
HTTP_TIMEOUT_S = 30
# Web Mercator (slippy-map) tiles only cover latitudes within this bound.
WEB_MERCATOR_MAX_LAT = 85.0511287798
# Probability above which a pixel is drawn as "building" in the overlay.
MASK_THRESHOLD = 0.5
OVERLAY_COLOR = np.array([255, 100, 100], dtype=np.float32)  # Light red


def load_model() -> UNet:
    """Return the cached UNet, downloading and loading its weights on first use.

    The model is built with the fixed architecture hyper-parameters below,
    loaded from ``MODEL_REPO/MODEL_FILENAME`` on the Hugging Face Hub, moved
    to ``DEVICE`` and put in eval mode. The result is stored in the
    module-level ``MODEL`` so later calls reuse it.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    model = UNet(in_channels=3, num_classes=2, base_features=64, depth=4, dropout=0.0)

    try:
        from safetensors.torch import load_file
        state_dict = load_file(weights_path, device=str(DEVICE))
    except ImportError:
        # Fallback when safetensors is not installed. weights_only=True stops
        # torch.load from executing arbitrary pickled code in a downloaded file.
        state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)

    # Accept both a bare state dict and a training checkpoint that wraps it.
    if "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    model.load_state_dict(state_dict)

    MODEL = model.to(DEVICE).eval()
    return MODEL


@torch.no_grad()
def predict_tile(model: UNet, image: np.ndarray) -> np.ndarray:
    """Run the model on one image and return the class-1 probability map.

    Args:
        model: The segmentation network returned by ``load_model``.
        image: Array indexed as (H, W, C). Callers in this file pass float32
            with C == 3, matching ``CHANNEL_MEAN``.

    Returns:
        An (H, W) array of softmax probabilities for class index 1, cropped
        back to the original image size.
    """
    h, w = image.shape[:2]
    # Pad each side up to a multiple of 16. Presumably this matches the UNet's
    # total downsampling factor at depth=4 (TODO confirm against models.unet).
    h_pad = int(math.ceil(h / 16) * 16)
    w_pad = int(math.ceil(w / 16) * 16)
    py1 = (h_pad - h) // 2
    px1 = (w_pad - w) // 2
    py2 = h_pad - h - py1
    px2 = w_pad - w - px1
    padded = np.pad(image, ((py1, py2), (px1, px2), (0, 0)), mode="symmetric")

    mean = CHANNEL_MEAN[np.newaxis, np.newaxis, :]
    normed = (padded - mean) / 255.0
    tensor = torch.from_numpy(normed.transpose(2, 0, 1)).unsqueeze(0).float().to(DEVICE)

    logits = model(tensor)
    probs = F.softmax(logits, dim=1).cpu().numpy()[0]
    return probs[1, py1: py1 + h, px1: px1 + w]


def read_image_file(path: str) -> np.ndarray:
    """Read an image from disk as a float32 (H, W, C) or (H, W) array.

    Tries rasterio first, which handles GeoTIFFs and multi-band rasters. If
    that fails for any reason, it falls back to Pillow for ordinary image
    formats.
    """
    try:
        with rasterio.open(path) as src:
            data = src.read()
        return np.moveaxis(data, 0, -1).astype(np.float32)
    except Exception:
        # Not readable by rasterio/GDAL: let Pillow try. Catching Exception
        # (not a bare except) keeps KeyboardInterrupt/SystemExit propagating.
        with Image.open(path) as img:
            return np.array(img, dtype=np.float32)


def fetch_esri_tile(lat, lon, zoom):
    """Download the ESRI World Imagery tile that contains (lat, lon).

    Args:
        lat: Latitude in degrees. It is clamped to the Web Mercator range.
        lon: Longitude in degrees. It is wrapped into [-180, 180).
        zoom: Slippy-map zoom level; truncated to int.

    Returns:
        The decoded tile as a uint8 RGB array.

    Raises:
        ValueError: If the response cannot be decoded as an image.
        urllib.error.URLError: On network failure.
    """
    z = int(zoom)
    n = 2 ** z
    # Tile rows are only defined inside the Web Mercator latitude band. Outside
    # it, the formula gives negative or out-of-range rows.
    lat = max(-WEB_MERCATOR_MAX_LAT, min(WEB_MERCATOR_MAX_LAT, float(lat)))
    lat_rad = math.radians(lat)
    xtile = int(((float(lon) + 180.0) % 360.0) / 360.0 * n)
    ytile = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
    # Guard the edges, e.g. lat exactly at the clamp bound.
    xtile = min(max(xtile, 0), n - 1)
    ytile = min(max(ytile, 0), n - 1)

    url = f"https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{ytile}/{xtile}"
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as response:
        content = response.read()

    nparr = np.frombuffer(content, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Failed to fetch tile from ESRI.")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _remove_quietly(path: str) -> None:
    """Delete a temp file. Cleanup is best-effort, so OSError is ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def _download_to_tempfile(url: str) -> str:
    """Download an http(s) image URL into a new temp file and return its path.

    The caller owns the returned file and must delete it.

    Raises:
        ValueError: If the scheme is not http/https, or the response
            Content-Type is neither image/* nor application/octet-stream.
        urllib.error.URLError: On network failure.
    """
    scheme = urllib.parse.urlparse(url).scheme.lower()
    if scheme not in ("http", "https"):
        # urlopen also honours file:// (and other schemes), which would let a
        # user make the server read its own local files.
        raise ValueError(f"unsupported URL scheme {scheme!r}; use http or https.")

    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)'})
    with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as response:
        content_type = response.info().get_content_type()
        if not content_type.startswith('image/') and content_type != 'application/octet-stream':
            raise ValueError(f"URL returned {content_type}, expected an image.")
        content = response.read()

    fd, path = tempfile.mkstemp(suffix=".tif")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(content)
    except BaseException:
        _remove_quietly(path)
        raise
    return path


def _load_input(mode, file_obj, url, lat, lon, zoom) -> np.ndarray:
    """Load the input image for the selected input method as a float32 array.

    Raises:
        gr.Error: On missing input, download/fetch failure, or unknown mode.
    """
    if mode == "File Upload":
        if not file_obj:
            raise gr.Error("Please upload an image file.")
        # Some Gradio versions pass an object exposing the path as ``.name``;
        # others pass the path string itself. Accept both.
        path = getattr(file_obj, "name", file_obj)
        return read_image_file(path)

    if mode == "Image URL":
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        try:
            path = _download_to_tempfile(url.strip())
        except Exception as e:
            raise gr.Error(f"Failed to download image from URL: {e}") from e
        try:
            return read_image_file(path)
        finally:
            # Previously the temp file was never deleted and leaked on every request.
            _remove_quietly(path)

    if mode == "Coordinates":
        if lat is None or lon is None:
            raise gr.Error("Please enter valid latitude and longitude.")
        try:
            return fetch_esri_tile(lat, lon, zoom).astype(np.float32)
        except Exception as e:
            raise gr.Error(f"Failed to fetch satellite imagery for coordinates: {e}") from e

    raise gr.Error(f"Unknown input method: {mode}")


def _to_three_channels(image: np.ndarray) -> np.ndarray:
    """Coerce a (H, W) or (H, W, C) image to exactly 3 channels.

    Single-channel and grayscale+alpha images replicate their first channel.
    Images with 3 or more channels keep the first three.
    """
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    if image.ndim != 3:
        raise gr.Error(f"Unsupported image shape: {image.shape}")
    if image.shape[-1] in (1, 2):
        return np.repeat(image[:, :, :1], 3, axis=-1)
    return image[:, :, :3]


def _render_outputs(image_f32: np.ndarray, probs: np.ndarray):
    """Build the three display images: input, score heatmap, and mask overlay."""
    # 1. Input image (values outside [0, 255] are clipped for display).
    input_uint8 = np.clip(image_f32, 0, 255).astype(np.uint8)

    # 2. Predicted building score as a Viridis heatmap. OpenCV returns BGR.
    heatmap_bgr = cv2.applyColorMap((probs * 255).astype(np.uint8), cv2.COLORMAP_VIRIDIS)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    # 3. Input with a 50/50 red blend on pixels above the threshold.
    binary_mask = probs > MASK_THRESHOLD
    overlay_rgb = input_uint8.copy()
    overlay_rgb[binary_mask] = (overlay_rgb[binary_mask] * 0.5 + OVERLAY_COLOR * 0.5).astype(np.uint8)

    return input_uint8, heatmap_rgb, overlay_rgb


def process_file(mode, file_obj, url, lat, lon, zoom):
    """Gradio callback: load the chosen input, segment it, and render the outputs.

    Returns:
        A tuple ``(input_uint8, heatmap_rgb, overlay_rgb)`` for the three
        output image components.

    Raises:
        gr.Error: For user-facing input or download problems.
    """
    image_f32 = _to_three_channels(_load_input(mode, file_obj, url, lat, lon, zoom))
    if image_f32.size == 0:
        raise gr.Error("The input image is empty.")

    model = load_model()
    probs = predict_tile(model, image_f32)
    return _render_outputs(image_f32, probs)


DESCRIPTION = """
# SpaceNet Building Detection

This model is a high-performance **U-Net** architecture with residual connections, trained to detect precise building footprints from high-resolution satellite imagery. It was trained on the **SpaceNet Area of Rio de Janeiro** dataset.

This demo showcases how the deep learning model can segment and highlight individual building structures in complex urban environments. More details can be found [here](https://github.com/HarshShinde0/spacenet).

The user needs to provide a 3-band (RGB) satellite geotiff image (.tif), or a standard image file (.png, .jpg). The model will output the raw imagery, a heatmap of the predicted building score, and the image overlaid with the predicted building masks.
"""

custom_css = """
#outputs-row { margin-top: 20px; }
"""

with gr.Blocks(title="SpaceNet Building Detection") as demo:
    gr.Markdown(DESCRIPTION)

    input_mode = gr.Radio(["File Upload", "Image URL", "Coordinates"], label="Input Method", value="File Upload")

    with gr.Group(visible=True) as group_file:
        file_input = gr.File(label="Upload File", file_types=["image", ".tif", ".tiff"])

    with gr.Group(visible=False) as group_url:
        url_input = gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/satellite_image.tif",
            value="https://cms.ongeo-intelligence.com/uploads/xlarge_webp_aerial_view_of_an_urbanized_area_2023_11_27_05_30_10_utc_1_ed20b8ab46.webp"
        )

    with gr.Group(visible=False) as group_coords:
        gr.Markdown("*Fetches a 256x256 high-resolution satellite tile from ESRI World Imagery.*")
        with gr.Row():
            lat_input = gr.Number(label="Latitude", value=41.506734)
            lon_input = gr.Number(label="Longitude", value=-81.679967)
        zoom_input = gr.Slider(minimum=15, maximum=19, value=17, step=1, label="Zoom Level")

    def update_visibility(mode):
        """Show only the input group that matches the selected input method."""
        return [
            gr.update(visible=mode == "File Upload"),
            gr.update(visible=mode == "Image URL"),
            gr.update(visible=mode == "Coordinates"),
        ]

    input_mode.change(
        fn=update_visibility,
        inputs=input_mode,
        outputs=[group_file, group_url, group_coords]
    )

    submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row(elem_id="outputs-row"):
        output_orig = gr.Image(label="Input Image")
        output_score = gr.Image(label="Predicted Building Score")
        output_overlay = gr.Image(label="Input + Predicted Buildings")

    gr.Markdown("### Examples")
    base_dir = os.path.dirname(os.path.abspath(__file__))
    gr.Examples(
        examples=[
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img1482.tif"),
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img2658.tif"),
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img5133.tif")
        ],
        inputs=file_input,
    )

    submit_btn.click(
        fn=process_file,
        inputs=[input_mode, file_input, url_input, lat_input, lon_input, zoom_input],
        outputs=[output_orig, output_score, output_overlay]
    )

if __name__ == "__main__":
    load_model()
    demo.launch(css=custom_css)