# road_omniview / app.py
# Committed by Vinit710 — commit 18c0b24 "Update app.py" (verified)
import os

import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
# Load the trained road-segmentation model once at import time; every request
# handled by predict_road_mask reuses this single instance.
MODEL_PATH = "road_detection_resnet_e50.h5"  # put your .h5 in the same directory as app.py

# A bare relative path is resolved against the current working directory, not
# against this file, so launching the app from another directory would not find
# the model. Prefer the copy next to this file; fall back to the original
# relative path so existing setups keep working.
_MODEL_NEXT_TO_APP = os.path.join(os.path.dirname(os.path.abspath(__file__)), MODEL_PATH)
_model_file = _MODEL_NEXT_TO_APP if os.path.isfile(_MODEL_NEXT_TO_APP) else MODEL_PATH

# compile=False: the app only runs inference, so the saved training
# configuration (optimizer, loss, metrics) is not needed. It also avoids load
# failures if the model was trained with a custom loss/metric that is not
# registered in this script.
model = tf.keras.models.load_model(_model_file, compile=False)
def predict_road_mask(image):
    """Segment roads in a single image with the loaded model.

    The image is brought to 3-channel RGB, resized to the model's input size,
    scaled to [0, 1] and passed through ``model``. Two displayable masks are
    returned.

    Args:
        image: Numpy image array as delivered by the ``gr.Image(type="numpy")``
            input of the interface (H x W x 3). Grayscale (H x W or H x W x 1)
            and 4-channel (RGBA) arrays are also accepted. ``None`` (no image
            uploaded) is tolerated.

    Returns:
        ``(prob_display, binary_display)``, both ``uint8`` arrays at the model
        resolution (256 x 256), not the upload's resolution:
        ``prob_display`` is the road probability scaled to 0-255, and
        ``binary_display`` is 255 where the probability exceeds 0.5, else 0.
        Returns ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when no image has been uploaded (e.g. Submit
        # pressed on an empty input); clear both outputs instead of crashing.
        return None, None

    target_size = (256, 256)  # (width, height) for cv2.resize; must match the model's input size
    threshold = 0.5  # probability above which a pixel is labelled "road"

    # Normalise the channel layout to 3-channel RGB. Gradio's numpy images are
    # already RGB, so no BGR->RGB swap is applied: the former
    # cv2.cvtColor(..., COLOR_BGR2RGB) turned RGB input into BGR.
    # NOTE(review): assumes the model was trained on RGB images, as the original
    # "convert to RGB" comment intended — confirm against the training pipeline.
    img = np.asarray(image)
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)  # grayscale -> 3 identical channels
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop alpha
    img = np.ascontiguousarray(img)  # OpenCV prefers contiguous input buffers

    img_resized = cv2.resize(img, target_size)
    # float32 is Keras' default compute dtype; add a batch axis of size 1.
    img_input = (img_resized.astype(np.float32) / 255.0)[np.newaxis, ...]

    # verbose=0 suppresses the per-call progress bar in the server log.
    pred = model.predict(img_input, verbose=0)[0]  # expected shape: (256, 256, 1)

    # Clip before the uint8 cast so values slightly outside [0, 1] cannot wrap
    # around (e.g. 1.01 * 255 -> 2).
    prob = np.clip(np.squeeze(pred), 0.0, 1.0)
    prob_display = (prob * 255).astype(np.uint8)
    binary_display = (prob > threshold).astype(np.uint8) * 255  # 0/255 for visibility
    return prob_display, binary_display
# Gradio UI: one uploaded image in, two masks out (probability and thresholded),
# both produced by predict_road_mask.
_satellite_input = gr.Image(type="numpy", label="Upload Satellite Image")
_mask_outputs = [
    gr.Image(type="numpy", label="Road Probability Mask"),
    gr.Image(type="numpy", label="Road Binary Mask"),
]

demo = gr.Interface(
    fn=predict_road_mask,
    inputs=_satellite_input,
    outputs=_mask_outputs,
    title="Road Detection from Satellite Images",
    description="Upload a satellite image to detect roads using a trained segmentation model.",
    examples=[],
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()