import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from PIL import Image

# -------------------------------
# Load Model and Data
# -------------------------------
MODEL_PATH = "satellite_segmentation_unet.h5"
X_TEST_PATH = "X_test.npy"
Y_TEST_PATH = "y_test.npy"
TOTAL_CLASSES_PATH = "total_classes.npy"

print("Loading model and test data...")
model = load_model(MODEL_PATH, compile=False)
X_test = np.load(X_TEST_PATH)
y_test = np.load(Y_TEST_PATH)
# np.ravel flattens both a 0-d array (what np.save writes for a plain int)
# and a 1-element 1-d array, so taking element [0] works for either layout.
total_classes = int(np.ravel(np.load(TOTAL_CLASSES_PATH))[0])

# -------------------------------
# Class color map
# -------------------------------
CLASS_COLORS = {
    0: [226, 169, 41],   # water
    1: [132, 41, 246],   # land
    2: [110, 193, 228],  # road
    3: [60, 16, 152],    # building
    4: [254, 221, 58],   # vegetation
    5: [155, 155, 155],  # unlabeled
}


# -------------------------------
# Helper Functions
# -------------------------------
def decode_segmentation(mask):
    """Convert a model output or a label mask into an RGB color image.

    Accepted layouts:
      * 4-D model output ``(1, H, W, num_classes)``: only the first item along
        axis 0 is used, and the class is ``argmax`` over the last axis.
      * 3-D array whose last axis has ``total_classes`` entries: ``argmax``
        over the last axis.
      * Any other 3-D array (e.g. ``(H, W, 1)``): singleton axes are squeezed.
      * 2-D array of class indices: used as is.

    Pixels whose class index is not a key of ``CLASS_COLORS`` stay black.

    Returns:
        ``(H, W, 3)`` uint8 color image.

    Raises:
        ValueError: if the input cannot be reduced to a 2-D class-index map.
    """
    mask = np.asarray(mask)
    if mask.ndim == 4:
        # Model output shape (1, H, W, num_classes)
        mask = np.argmax(mask[0], axis=-1)
    elif mask.ndim == 3 and mask.shape[-1] == total_classes:
        mask = np.argmax(mask, axis=-1)
    elif mask.ndim == 3:
        mask = mask.squeeze()

    # Fail with a clear message instead of an obscure IndexError/broadcast
    # error further down when the mask is not a 2-D class map.
    if mask.ndim != 2:
        raise ValueError(
            f"Cannot decode mask of shape {mask.shape}: expected a 2-D class "
            f"map or an array whose last axis has {total_classes} classes."
        )

    color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for class_idx, color in CLASS_COLORS.items():
        color_mask[mask == class_idx] = color
    return color_mask


def predict_segmentation(index):
    """Run inference on one test image and build the three visualizations.

    Args:
        index: Position of the image in ``X_test``. The Gradio dropdown passes
            it as a string, so it is converted with ``int``.

    Returns:
        ``(image, gt_mask, pred_mask)`` as uint8 arrays. If anything fails, it
        returns ``(None, None, None)`` and prints the error and traceback.
    """
    try:
        index = int(index)
        sample = X_test[index]
        # Assumes X_test is scaled to [0, 1] (TODO confirm). Clipping stops
        # out-of-range values from producing garbage in the uint8 cast.
        image = np.clip(sample * 255, 0, 255).astype(np.uint8)

        # Model prediction on a batch of one. verbose=0 suppresses the Keras
        # progress bar that would otherwise be printed on every request.
        pred = model.predict(np.expand_dims(sample, axis=0), verbose=0)
        pred_mask = decode_segmentation(pred)

        # Ground truth
        gt_mask = decode_segmentation(y_test[index])

        return image, gt_mask, pred_mask
    except Exception as e:
        # Best-effort UI: show empty outputs instead of failing the request,
        # but keep the full traceback so the failure can be debugged.
        print("Error during prediction:", e)
        traceback.print_exc()
        return None, None, None


# -------------------------------
# Gradio Interface
# -------------------------------
title = "🛰️ Satellite Segmentation (U-Net)"
# Markdown shown under the title. The bullets must sit on their own lines,
# otherwise Markdown renders them as one run-on paragraph.
description = """
Select one of the **test images** from the dataset to visualize:

- The original input image
- The ground truth mask
- The predicted segmentation mask from the trained **U-Net**
"""

# Upper bound on how many test images the dropdown offers.
MAX_DROPDOWN_IMAGES = 50
indices = list(range(min(MAX_DROPDOWN_IMAGES, len(X_test))))
index_choices = [str(i) for i in indices]

interface = gr.Interface(
    fn=predict_segmentation,
    inputs=gr.Dropdown(
        choices=index_choices,
        label="Select Test Image Index",
        # Default to the first available index ("0" when the set is non-empty).
        # With an empty test set, fall back to no selection rather than a value
        # that is not among the choices.
        value=index_choices[0] if index_choices else None,
        interactive=True,
    ),
    outputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Ground Truth Mask"),
        gr.Image(label="Predicted Mask"),
    ],
    title=title,
    description=description,
    # NOTE(review): newer Gradio releases (5.x) rename this option to
    # `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never",
    theme="gradio/soft",
)

# -------------------------------
# Launch
# -------------------------------
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860)