# spacenet / app.py
# Commit 50ae058 "Create app.py" by harshinde (verified).
"""
SpaceNet Building Detection
"""
import math
import os
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import urllib.request
from huggingface_hub import hf_hub_download
import rasterio
from PIL import Image
from models.unet import UNet
# Hugging Face Hub repository and file holding the pretrained U-Net weights.
MODEL_REPO = "harshinde/spacenet"
MODEL_FILENAME = "model.safetensors"

# Per-channel mean (in input channel order) subtracted from every pixel in
# predict_tile before the result is scaled by 1/255.
CHANNEL_MEAN = np.asarray((71.2274, 78.3385, 56.2296), dtype=np.float32)

# Use the GPU when torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Lazily populated model cache; filled on the first call to load_model().
MODEL = None
def load_model() -> UNet:
    """Return the shared segmentation model, building it on the first call.

    On first use the checkpoint ``MODEL_FILENAME`` is fetched from the
    ``MODEL_REPO`` Hub repository, a 3-input-channel / 2-class ``UNet`` is
    constructed, the weights are loaded into it, and the network is moved to
    ``DEVICE`` and put in eval mode. The instance is cached in the module-level
    ``MODEL`` and every later call returns that same object.

    Returns:
        The ready-to-use ``UNet`` instance.
    """
    global MODEL
    if MODEL is None:
        checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        net = UNet(in_channels=3, num_classes=2, base_features=64, depth=4, dropout=0.0)
        try:
            from safetensors.torch import load_file
            weights = load_file(checkpoint_path, device=str(DEVICE))
        except ImportError:
            # NOTE(review): only reached when safetensors is not installed.
            # torch.load is unlikely to parse a .safetensors file, and
            # weights_only=False unpickles arbitrary objects - acceptable only
            # because the checkpoint comes from a fixed, trusted repository.
            weights = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
        # Accept either a bare state dict or one wrapped under "model_state_dict".
        if "model_state_dict" in weights:
            weights = weights["model_state_dict"]
        net.load_state_dict(weights)
        MODEL = net.to(DEVICE).eval()
    return MODEL
@torch.no_grad()
def predict_tile(model: UNet, image: np.ndarray) -> np.ndarray:
    """Run ``model`` on one image and return the class-1 probability map.

    Both spatial sides are symmetrically padded up to the next multiple of 16
    (presumably the total downsampling factor of the depth-4 UNet - confirm).
    ``CHANNEL_MEAN`` is subtracted, the result is scaled by 1/255, and a batch
    of one is sent through the network on ``DEVICE``. The softmax probability
    of class 1 (treated by this app as "building") is cropped back to the
    original extent.

    Args:
        model: Network mapping an ``(1, C, H, W)`` tensor to 2-class logits.
        image: ``(H, W, C)`` array; callers pass three channels.

    Returns:
        ``(H, W)`` array of class-1 probabilities.
    """
    height, width = image.shape[:2]
    # How much each side must grow to reach the next multiple of 16.
    extra_h = -height % 16
    extra_w = -width % 16
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    padded = np.pad(image, ((top, bottom), (left, right), (0, 0)), mode="symmetric")

    normalized = (padded - CHANNEL_MEAN) / 255.0
    # HWC -> CHW, then add the batch axis.
    chw = np.transpose(normalized, (2, 0, 1))
    batch = torch.from_numpy(chw)[None].float().to(DEVICE)

    class_one = F.softmax(model(batch), dim=1)[0, 1].cpu().numpy()
    return class_one[top:top + height, left:left + width]
def read_image_file(path: str) -> np.ndarray:
    """Load the image at ``path`` as a float32 array with channels last.

    Rasterio (GDAL) is tried first so GeoTIFFs load with all of their bands.
    Its band-first ``(bands, H, W)`` result is moved to ``(H, W, bands)``.
    If rasterio cannot open or read the file, Pillow is used instead. That path
    yields ``(H, W)`` for single-band modes or ``(H, W, C)`` otherwise, with
    palette/LA/CMYK/YCbCr images converted to RGB first.

    Args:
        path: Filesystem path of the image.

    Returns:
        The pixel data as a float32 ``np.ndarray``.

    Raises:
        Whatever Pillow raises when it also fails to open the file
        (e.g. ``FileNotFoundError`` or ``PIL.UnidentifiedImageError``).
    """
    try:
        with rasterio.open(path) as src:
            bands = src.read()
    except Exception:
        # Not readable by rasterio/GDAL: fall back to Pillow. Deliberately broad,
        # but unlike a bare ``except:`` it lets KeyboardInterrupt/SystemExit through.
        # The context manager closes Pillow's file handle promptly.
        with Image.open(path) as img:
            # np.array() on these modes would yield palette indices, a
            # 2-channel array, or non-RGB colour components; downstream code
            # expects RGB (or plain grayscale) values.
            if img.mode in ("P", "LA", "CMYK", "YCbCr"):
                pixels = img.convert("RGB")
            else:
                pixels = img
            return np.array(pixels, dtype=np.float32)
    return np.moveaxis(bands, 0, -1).astype(np.float32)
def fetch_esri_tile(lat, lon, zoom, timeout: float = 30.0):
    """Download the ESRI World Imagery tile that contains (``lat``, ``lon``).

    The point is converted to Web Mercator "slippy map" tile indices at
    ``zoom``. Latitude is clamped to the projection's limit (about +/-85.0511
    degrees) and longitude is wrapped into [-180, 180). Both indices are then
    clamped into ``[0, 2**zoom - 1]``, so polar or antimeridian inputs still
    address a real tile instead of producing an out-of-range URL.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.
        zoom: Zoom level. It is truncated to ``int`` once and used both for the
            tile math and in the URL (the UI slider may pass a float).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        RGB uint8 image array decoded from the tile.

    Raises:
        ValueError: if the response is empty or cannot be decoded as an image.
        urllib.error.URLError / OSError: on network failure or timeout.
    """
    z = int(zoom)
    n = 2 ** z
    # Web Mercator's latitude limit: the latitude whose y maps exactly to the map edge.
    max_lat = math.degrees(math.atan(math.sinh(math.pi)))
    lat_rad = math.radians(max(-max_lat, min(max_lat, float(lat))))
    # Wrapping makes lon=180 land on tile column 0 (the same meridian as -180).
    x_frac = ((float(lon) + 180.0) % 360.0) / 360.0
    y_frac = (1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0
    # Clamp guards the closed edges (e.g. lat=-max_lat gives y == n) and float rounding.
    xtile = min(max(int(x_frac * n), 0), n - 1)
    ytile = min(max(int(y_frac * n), 0), n - 1)
    url = f"https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{ytile}/{xtile}"
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(req, timeout=timeout) as response:
        content = response.read()
    if not content:
        raise ValueError("Failed to fetch tile from ESRI.")
    img = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Failed to fetch tile from ESRI.")
    # OpenCV decodes to BGR channel order; the rest of the app works in RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def _download_url_image(url: str, timeout: float = 30.0) -> np.ndarray:
    """Download an http(s) image URL and decode it; raise gr.Error on failure."""
    import contextlib
    import tempfile
    from urllib.parse import urlparse

    url = url.strip()
    # Only fetch remote http(s) resources. urllib would otherwise also honour
    # file:// (and other) schemes, letting a visitor read files from the host.
    if urlparse(url).scheme not in ("http", "https"):
        raise gr.Error("Please provide an http(s) image URL.")
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)'})
        with urllib.request.urlopen(req, timeout=timeout) as response:
            content_type = response.info().get_content_type()
            if not content_type.startswith('image/') and not content_type == 'application/octet-stream':
                raise ValueError(f"URL returned {content_type}, expected an image.")
            content = response.read()
    except Exception as e:
        raise gr.Error(f"Failed to download image from URL: {e}") from e
    # read_image_file takes a path, so stage the bytes in a temporary file.
    # delete=False allows reopening it by name (required on Windows). It is
    # removed explicitly below so each request no longer leaks a file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    try:
        return read_image_file(tmp_path)
    finally:
        # Best-effort cleanup; must not mask the decode result or error.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)


def _to_three_channels(image: np.ndarray) -> np.ndarray:
    """Coerce a decoded image to exactly three channels, shape (H, W, 3)."""
    if image.ndim == 2:
        # Single-band image: replicate it into three identical channels.
        return np.stack([image] * 3, axis=-1)
    if image.ndim != 3:
        raise gr.Error(f"Unsupported image with shape {image.shape}.")
    channels = image.shape[-1]
    if channels == 1:
        return np.concatenate([image] * 3, axis=-1)
    if channels >= 3:
        # Keep the first three bands (e.g. drops alpha or extra bands).
        return image[:, :, :3]
    raise gr.Error(f"Unsupported image with {channels} channels; expected 1, 3 or more.")


def _render_outputs(image: np.ndarray, probs: np.ndarray, threshold: float = 0.5):
    """Build the three display images: input, score heatmap and mask overlay."""
    # 1. Input image, clipped into the displayable 0-255 range.
    input_uint8 = np.clip(image, 0, 255).astype(np.uint8)
    # 2. Predicted building score as a Viridis heatmap (OpenCV emits BGR).
    heatmap_bgr = cv2.applyColorMap((probs * 255).astype(np.uint8), cv2.COLORMAP_VIRIDIS)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    # 3. Input with pixels above the threshold blended 50/50 with light red.
    binary_mask = probs > threshold
    overlay_rgb = input_uint8.copy()
    color = np.array([255, 100, 100], dtype=np.float32)  # Light red
    overlay_rgb[binary_mask] = (overlay_rgb[binary_mask] * 0.5 + color * 0.5).astype(np.uint8)
    return input_uint8, heatmap_rgb, overlay_rgb


def process_file(mode, file_obj, url, lat, lon, zoom):
    """Gradio submit handler: acquire an image, run the model, render outputs.

    Args:
        mode: Selected input method: "File Upload", "Image URL" or "Coordinates".
        file_obj: Value of the upload component. This is either a path string or
            an object exposing the path as ``.name``.
        url: Image URL, used in "Image URL" mode.
        lat: Latitude in degrees, used in "Coordinates" mode.
        lon: Longitude in degrees, used in "Coordinates" mode.
        zoom: ESRI tile zoom level, used in "Coordinates" mode.

    Returns:
        Tuple ``(input, heatmap, overlay)`` of uint8 images for display.

    Raises:
        gr.Error: on missing or invalid input, an unsupported channel count,
            or a failed download. Gradio shows the message to the user.
    """
    if mode == "File Upload":
        if not file_obj:
            raise gr.Error("Please upload an image file.")
        # Newer Gradio passes a path string (a str subclass that also has .name);
        # older versions pass a tempfile wrapper exposing .name.
        path = getattr(file_obj, "name", file_obj)
        image_f32 = read_image_file(path)
    elif mode == "Image URL":
        if not url or url.strip() == "":
            raise gr.Error("Please provide an image URL.")
        image_f32 = _download_url_image(url)
    elif mode == "Coordinates":
        if lat is None or lon is None:
            raise gr.Error("Please enter valid latitude and longitude.")
        try:
            image_f32 = fetch_esri_tile(lat, lon, zoom).astype(np.float32)
        except Exception as e:
            raise gr.Error(f"Failed to fetch satellite imagery for coordinates: {e}") from e
    else:
        # Previously fell through and failed with UnboundLocalError on image_f32.
        raise gr.Error(f"Unknown input method: {mode!r}")

    image_f32 = _to_three_channels(image_f32)
    probs = predict_tile(load_model(), image_f32)
    return _render_outputs(image_f32, probs)
# Markdown rendered at the top of the Gradio page.
DESCRIPTION = """
# SpaceNet Building Detection

This model is a high-performance **U-Net** architecture with residual connections, trained to detect precise building footprints from high-resolution satellite imagery. It was trained on the **SpaceNet Rio de Janeiro (AOI 1)** dataset. This demo showcases how the deep learning model can segment and highlight building structures in complex urban environments. More details can be found [here](https://github.com/HarshShinde0/spacenet).

Provide the input in one of three ways: upload a 3-band (RGB) satellite GeoTIFF (.tif) or a standard image file (.png, .jpg), paste a direct image URL, or enter latitude/longitude coordinates to fetch a satellite tile from ESRI World Imagery. The model will output the input imagery, a heatmap of the predicted building score, and the image overlaid with the predicted building mask.
"""

# Extra CSS passed to demo.launch(); adds spacing above the outputs row.
custom_css = """
#outputs-row { margin-top: 20px; }
"""
# Gradio UI. Components are placed on the page in the order they are created
# inside the `with` contexts, so statement order here defines the layout.
with gr.Blocks(title="SpaceNet Building Detection") as demo:
    gr.Markdown(DESCRIPTION)
    # Chooses which input group is visible and which branch process_file takes;
    # these choice strings must match the ones process_file compares against.
    input_mode = gr.Radio(["File Upload", "Image URL", "Coordinates"], label="Input Method", value="File Upload")
    # One group per input method; only the group for the selected method is shown.
    with gr.Group(visible=True) as group_file:
        file_input = gr.File(label="Upload File", file_types=["image", ".tif", ".tiff"])
    with gr.Group(visible=False) as group_url:
        url_input = gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/satellite_image.tif",
            value="https://cms.ongeo-intelligence.com/uploads/xlarge_webp_aerial_view_of_an_urbanized_area_2023_11_27_05_30_10_utc_1_ed20b8ab46.webp"
        )
    with gr.Group(visible=False) as group_coords:
        gr.Markdown("*Fetches a 256x256 high-resolution satellite tile from ESRI World Imagery.*")
        with gr.Row():
            lat_input = gr.Number(label="Latitude", value=41.506734)
            lon_input = gr.Number(label="Longitude", value=-81.679967)
        zoom_input = gr.Slider(minimum=15, maximum=19, value=17, step=1, label="Zoom Level")

    def update_visibility(mode):
        """Return visibility updates for (file, URL, coordinates) groups.

        Only the group that matches the selected ``mode`` is made visible.
        """
        return [
            gr.update(visible=mode=="File Upload"),
            gr.update(visible=mode=="Image URL"),
            gr.update(visible=mode=="Coordinates")
        ]

    # Output order must match the list returned by update_visibility.
    input_mode.change(
        fn=update_visibility,
        inputs=input_mode,
        outputs=[group_file, group_url, group_coords]
    )
    submit_btn = gr.Button("Submit", variant="primary")
    with gr.Row(elem_id="outputs-row"):
        output_orig = gr.Image(label="Input Image")
        output_score = gr.Image(label="Predicted Building Score")
        output_overlay = gr.Image(label="Input + Predicted Buildings")
    gr.Markdown("### Examples")
    # Example files are resolved relative to this script, not the working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): examples only fill file_input; process_file reads that value
    # only while the "File Upload" method is selected.
    gr.Examples(
        examples=[
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img1482.tif"),
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img2658.tif"),
            os.path.join(base_dir, "examples/3band_AOI_1_RIO_img5133.tif")
        ],
        inputs=file_input,
    )
    # Input order must match process_file(mode, file_obj, url, lat, lon, zoom);
    # output order matches its (input, heatmap, overlay) return tuple.
    submit_btn.click(
        fn=process_file,
        inputs=[input_mode, file_input, url_input, lat_input, lon_input, zoom_input],
        outputs=[output_orig, output_score, output_overlay]
    )
if __name__ == "__main__":
    # Build the model (downloading the checkpoint if needed) before serving, so
    # presumably the first request does not pay that cost.
    load_model()
    # NOTE(review): passing `css` to launch() instead of gr.Blocks() relies on a
    # recent Gradio release - confirm against the pinned gradio version.
    demo.launch(css=custom_css)