| """ |
| SpaceNet Building Detection |
| """ |
|
|
| import math |
| import os |
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import urllib.request |
| from huggingface_hub import hf_hub_download |
| import rasterio |
| from PIL import Image |
|
|
| from models.unet import UNet |
|
|
# Hugging Face Hub coordinates of the trained segmentation weights.
MODEL_REPO = "harshinde/spacenet"
MODEL_FILENAME = "model.safetensors"

# Per-channel mean subtracted from the input image before inference
# (applied in predict_tile).
CHANNEL_MEAN = np.asarray((71.2274, 78.3385, 56.2296), dtype=np.float32)


# Inference device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Process-wide model cache; load_model() fills it on first use.
MODEL = None
|
|
def load_model() -> UNet:
    """Return the shared UNet instance, downloading and loading it on first use.

    The weights file is fetched from the Hugging Face Hub
    (``MODEL_REPO`` / ``MODEL_FILENAME``), loaded into a freshly built UNet,
    moved to ``DEVICE`` and switched to eval mode. The result is cached in
    the module-level ``MODEL`` so later calls return immediately.

    Returns:
        The cached model, in eval mode on ``DEVICE``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    model = UNet(in_channels=3, num_classes=2, base_features=64, depth=4, dropout=0.0)

    try:
        from safetensors.torch import load_file
    except ImportError:
        # Fallback when safetensors is not installed. The file comes from the
        # network, so restrict unpickling to tensors/plain containers
        # (weights_only=True) instead of executing arbitrary pickled objects.
        state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    else:
        state_dict = load_file(weights_path, device=str(DEVICE))

    # Accept both a bare state dict and a checkpoint that nests it under
    # "model_state_dict".
    if "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    model.load_state_dict(state_dict)

    MODEL = model.to(DEVICE).eval()
    return MODEL
|
|
@torch.no_grad()
def predict_tile(model: UNet, image: np.ndarray) -> np.ndarray:
    """Run the model on one image and return the channel-1 probability map.

    The image is symmetrically padded so both spatial sides are multiples of
    16, mean-centred with ``CHANNEL_MEAN``, scaled by 1/255 and fed to the
    model as a single-item batch on ``DEVICE``. Softmax is taken over the
    channel axis and the padding is cropped away again.

    Args:
        model: Callable network; assumed to return logits shaped
            (1, classes, H_pad, W_pad) -- TODO confirm against UNet.
        image: Array indexed as (H, W, C); C must broadcast against the
            3-entry ``CHANNEL_MEAN``.

    Returns:
        2-D array of shape (H, W) holding softmax channel 1 (the demo
        displays this as the building score).
    """
    height, width = image.shape[:2]

    # Total padding needed to reach the next multiple of 16 on each axis,
    # split as evenly as possible between the two sides.
    extra_h = int(math.ceil(height / 16) * 16) - height
    extra_w = int(math.ceil(width / 16) * 16) - width
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left

    padded = np.pad(image, ((top, bottom), (left, right), (0, 0)), mode="symmetric")
    centred = padded - CHANNEL_MEAN[np.newaxis, np.newaxis, :]
    normed = centred / 255.0

    # HWC -> CHW, then add the batch axis expected by the network.
    chw = torch.from_numpy(normed.transpose(2, 0, 1))
    batch = chw.unsqueeze(0).float().to(DEVICE)

    probabilities = F.softmax(model(batch), dim=1).cpu().numpy()[0]

    # Crop the padding back off so the output aligns with the input image.
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return probabilities[1, rows, cols]
|
|
def read_image_file(path: str) -> np.ndarray:
    """Read an image from *path* as a float32 array with channels last.

    rasterio is tried first; its (bands, rows, cols) result is moved to
    (rows, cols, bands). If rasterio cannot read the file, the image is
    decoded with PIL instead.

    Args:
        path: Filesystem path of the image.

    Returns:
        float32 array; (H, W, C) from rasterio, or whatever layout PIL yields
        for the image mode (2-D for single-band images).

    Raises:
        Whatever PIL raises when the fallback cannot decode the file either.
    """
    try:
        with rasterio.open(path) as src:
            bands = src.read()
    except Exception:
        # Best-effort fallback for files rasterio cannot open. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate. The context manager closes the PIL file handle once the
        # pixels have been copied into the array.
        with Image.open(path) as img:
            return np.array(img, dtype=np.float32)
    return np.moveaxis(bands, 0, -1).astype(np.float32)
|
|
def fetch_esri_tile(lat, lon, zoom, timeout=30.0):
    """Download the ESRI World Imagery tile containing a coordinate.

    The latitude/longitude pair is converted to slippy-map tile indices at
    the given zoom level, the tile is fetched over HTTPS and decoded.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        zoom: Zoom level; truncated to an integer before use so the tile
            indices and the URL always refer to the same level.
        timeout: Seconds to wait for the tile server before giving up.

    Returns:
        The decoded tile as an RGB uint8 array.

    Raises:
        ValueError: If the zoom level is negative or the response cannot be
            decoded as an image.
        urllib.error.URLError: On network failure or timeout.
    """
    zoom = int(zoom)
    if zoom < 0:
        raise ValueError("Zoom level must be non-negative.")

    n = 2 ** zoom
    lat_rad = math.radians(lat)
    xtile = int((lon + 180.0) / 360.0 * n)
    ytile = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)

    # lon == 180 or latitudes beyond the Web Mercator limit would otherwise
    # produce indices outside the valid [0, n - 1] range.
    xtile = min(max(xtile, 0), n - 1)
    ytile = min(max(ytile, 0), n - 1)

    url = f"https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{zoom}/{ytile}/{xtile}"
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    # A timeout keeps a stalled connection from blocking the request forever.
    with urllib.request.urlopen(req, timeout=timeout) as response:
        content = response.read()

    nparr = np.frombuffer(content, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Failed to fetch tile from ESRI.")
    # OpenCV decodes to BGR; the rest of the app works in RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
def _download_image(url, timeout=30.0):
    """Download *url* to a temporary file, decode it, then delete the file."""
    import tempfile
    import urllib.parse
    import urllib.request

    url = url.strip()
    tmp_path = None
    try:
        try:
            # urlopen would also accept file:// and ftp:// URLs; the URL is
            # user-supplied, so only plain web schemes are allowed.
            if urllib.parse.urlparse(url).scheme.lower() not in ("http", "https"):
                raise ValueError("only http:// and https:// URLs are supported.")
            req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)'})
            with urllib.request.urlopen(req, timeout=timeout) as response:
                content_type = response.info().get_content_type()
                if not content_type.startswith('image/') and content_type != 'application/octet-stream':
                    raise ValueError(f"URL returned {content_type}, expected an image.")
                content = response.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp:
                tmp_path = tmp.name
                tmp.write(content)
        except Exception as e:
            raise gr.Error(f"Failed to download image from URL: {e}") from e
        return read_image_file(tmp_path)
    finally:
        # The file was created with delete=False, so remove it explicitly.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def _to_three_channels(image):
    """Coerce a 2-D or (H, W, C) image to exactly three channels."""
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    if image.ndim != 3:
        raise gr.Error(f"Unsupported image shape: {image.shape}")
    if image.shape[-1] >= 3:
        # Drop alpha / extra bands.
        return image[:, :, :3]
    # One band, or two (e.g. grey + alpha): replicate the first band.
    return np.repeat(image[:, :, :1], 3, axis=-1)


def _render_outputs(image_f32, probs, threshold=0.5):
    """Build the (input, heatmap, overlay) uint8 RGB images for display."""
    input_uint8 = np.clip(image_f32, 0, 255).astype(np.uint8)

    heatmap_bgr = cv2.applyColorMap((probs * 255).astype(np.uint8), cv2.COLORMAP_VIRIDIS)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    # Blend a fixed tint 50/50 into every pixel above the threshold.
    binary_mask = probs > threshold
    color = np.array([255, 100, 100], dtype=np.float32)
    overlay_rgb = input_uint8.copy()
    overlay_rgb[binary_mask] = (overlay_rgb[binary_mask] * 0.5 + color * 0.5).astype(np.uint8)

    return input_uint8, heatmap_rgb, overlay_rgb


def process_file(mode, file_obj, url, lat, lon, zoom):
    """Gradio callback: obtain an image, run the model, return three views.

    Args:
        mode: One of "File Upload", "Image URL" or "Coordinates".
        file_obj: Uploaded file (an object with a ``.name`` path, or the path
            itself); used in "File Upload" mode.
        url: Image URL; used in "Image URL" mode.
        lat: Latitude in degrees; used in "Coordinates" mode.
        lon: Longitude in degrees; used in "Coordinates" mode.
        zoom: Tile zoom level; used in "Coordinates" mode.

    Returns:
        Tuple ``(input_uint8, heatmap_rgb, overlay_rgb)``: the input image
        clipped to uint8, a viridis heatmap of the predicted score, and the
        input with pixels scoring above 0.5 tinted.

    Raises:
        gr.Error: On missing input, download/fetch failure, an unknown mode,
            or an image whose shape cannot be converted to three channels.
    """
    if mode == "File Upload":
        if not file_obj:
            raise gr.Error("Please upload an image file.")
        # Works for both file wrappers exposing .name and plain path strings.
        path = getattr(file_obj, "name", file_obj)
        image_f32 = read_image_file(path)

    elif mode == "Image URL":
        if not url or not url.strip():
            raise gr.Error("Please provide an image URL.")
        image_f32 = _download_image(url)

    elif mode == "Coordinates":
        if lat is None or lon is None:
            raise gr.Error("Please enter valid latitude and longitude.")
        try:
            image_f32 = fetch_esri_tile(lat, lon, zoom).astype(np.float32)
        except Exception as e:
            raise gr.Error(f"Failed to fetch satellite imagery for coordinates: {e}") from e

    else:
        # Without this branch an unexpected mode would surface as an
        # UnboundLocalError on image_f32 further down.
        raise gr.Error(f"Unknown input method: {mode!r}")

    image_f32 = _to_three_channels(image_f32)

    model = load_model()
    probs = predict_tile(model, image_f32)

    return _render_outputs(image_f32, probs)
|
|
# Markdown rendered at the top of the demo page.
DESCRIPTION = """
# SpaceNet Building Detection

This model is a high-performance **U-Net** architecture with residual connections, trained to detect precise building footprints from high-resolution satellite imagery. It was trained on the **SpaceNet Area of Rio de Janeiro** dataset. This demo showcases how the deep learning model can segment and highlight individual building structures in complex urban environments. More details can be found [here](https://github.com/HarshShinde0/spacenet).

The user needs to provide a 3-band (RGB) satellite geotiff image (.tif), or a standard image file (.png, .jpg). The model will output the raw imagery, a heatmap of the predicted building score, and the image overlaid with the predicted building masks.
"""


# Extra stylesheet handed to demo.launch(); targets the row created with
# elem_id="outputs-row".
custom_css = """
#outputs-row { margin-top: 20px; }
"""
|
|
with gr.Blocks(title="SpaceNet Building Detection") as demo:
    gr.Markdown(DESCRIPTION)

    # --- Input selection ---------------------------------------------------
    input_mode = gr.Radio(
        ["File Upload", "Image URL", "Coordinates"],
        label="Input Method",
        value="File Upload",
    )

    # One group per input method; only the group matching the selected radio
    # option is visible.
    with gr.Group(visible=True) as group_file:
        file_input = gr.File(
            label="Upload File",
            file_types=["image", ".tif", ".tiff"],
        )

    with gr.Group(visible=False) as group_url:
        url_input = gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/satellite_image.tif",
            value="https://cms.ongeo-intelligence.com/uploads/xlarge_webp_aerial_view_of_an_urbanized_area_2023_11_27_05_30_10_utc_1_ed20b8ab46.webp",
        )

    with gr.Group(visible=False) as group_coords:
        gr.Markdown("*Fetches a 256x256 high-resolution satellite tile from ESRI World Imagery.*")
        # NOTE(review): the nesting of the zoom slider (inside vs. below the
        # row) is not recoverable from the flattened source -- confirm layout.
        with gr.Row():
            lat_input = gr.Number(label="Latitude", value=41.506734)
            lon_input = gr.Number(label="Longitude", value=-81.679967)
            zoom_input = gr.Slider(
                minimum=15,
                maximum=19,
                value=17,
                step=1,
                label="Zoom Level",
            )

    def update_visibility(mode):
        """Return one visibility update per input group, in group order."""
        return [
            gr.update(visible=(mode == option))
            for option in ("File Upload", "Image URL", "Coordinates")
        ]

    input_mode.change(
        fn=update_visibility,
        inputs=input_mode,
        outputs=[group_file, group_url, group_coords],
    )

    submit_btn = gr.Button("Submit", variant="primary")

    # --- Outputs -----------------------------------------------------------
    with gr.Row(elem_id="outputs-row"):
        output_orig = gr.Image(label="Input Image")
        output_score = gr.Image(label="Predicted Building Score")
        output_overlay = gr.Image(label="Input + Predicted Buildings")

    # --- Bundled examples, resolved relative to this file ------------------
    gr.Markdown("### Examples")
    base_dir = os.path.dirname(os.path.abspath(__file__))
    gr.Examples(
        examples=[
            os.path.join(base_dir, f"examples/{filename}")
            for filename in (
                "3band_AOI_1_RIO_img1482.tif",
                "3band_AOI_1_RIO_img2658.tif",
                "3band_AOI_1_RIO_img5133.tif",
            )
        ],
        inputs=file_input,
    )

    # Wire the submit button to the inference callback.
    submit_btn.click(
        fn=process_file,
        inputs=[input_mode, file_input, url_input, lat_input, lon_input, zoom_input],
        outputs=[output_orig, output_score, output_overlay],
    )
|
|
if __name__ == "__main__":
    # Populate the model cache before the UI starts serving, so the download
    # and weight loading do not happen inside the first request.
    load_model()
    demo.launch(css=custom_css)
|
|