Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
| from PIL import Image | |
# -------------------------------
# Load Model and Data
# -------------------------------
MODEL_PATH = "satellite_segmentation_unet.h5"
X_TEST_PATH = "X_test.npy"
Y_TEST_PATH = "y_test.npy"
TOTAL_CLASSES_PATH = "total_classes.npy"

print("Loading model and test data...")
# The model is only used for predict() in this app, so it is loaded without
# compiling it.
model = load_model(MODEL_PATH, compile=False)
X_test = np.load(X_TEST_PATH)
y_test = np.load(Y_TEST_PATH)
# The class count may have been saved either as a 0-d scalar array
# (np.save(path, 6)) or as a 1-element array (np.save(path, [6])).
# Indexing a 0-d array with [0] raises IndexError; ravel() turns both forms
# into a 1-D array so taking element 0 is always valid.
total_classes = int(np.ravel(np.load(TOTAL_CLASSES_PATH))[0])
# -------------------------------
# Class color map
# -------------------------------
# RGB colour for each class label; a colour's position in the list below is
# the class index it is keyed under.
CLASS_COLORS = dict(
    enumerate(
        [
            [226, 169, 41],   # 0: water
            [132, 41, 246],   # 1: land
            [110, 193, 228],  # 2: road
            [60, 16, 152],    # 3: building
            [254, 221, 58],   # 4: vegetation
            [155, 155, 155],  # 5: unlabeled
        ]
    )
)
# -------------------------------
# Helper Functions
# -------------------------------
def decode_segmentation(mask, num_classes=None, colors=None):
    """Convert a model output or a label mask into an RGB image.

    Accepted layouts:
      * 4-D ``(B, H, W, C)`` batched output -> argmax over the last axis of
        the first batch element;
      * 3-D ``(H, W, num_classes)`` one-hot / probability map -> argmax over
        the last axis;
      * other 3-D arrays whose singleton axes squeeze down to ``(H, W)``,
        e.g. ``(H, W, 1)`` -> squeezed label map;
      * 2-D ``(H, W)`` label map -> used as is.

    Args:
        mask: Array-like in one of the layouts above.
        num_classes: Last-axis size that marks a one-hot/probability map.
            Defaults to the module-level ``total_classes``.
        colors: Mapping ``class_index -> [R, G, B]``. Defaults to
            ``CLASS_COLORS``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Pixels whose label has no
        entry in ``colors`` stay black ``(0, 0, 0)``.

    Raises:
        ValueError: If ``mask`` cannot be reduced to a 2-D label map, e.g. a
            3-channel RGB mask when ``num_classes`` is not 3.
    """
    if num_classes is None:
        num_classes = total_classes
    if colors is None:
        colors = CLASS_COLORS

    mask = np.asarray(mask)
    if mask.ndim == 4:  # batched model output, e.g. (1, H, W, num_classes)
        labels = np.argmax(mask[0], axis=-1)
    elif mask.ndim == 3 and mask.shape[-1] == num_classes:
        labels = np.argmax(mask, axis=-1)
    elif mask.ndim == 3:
        labels = mask.squeeze()
    else:
        labels = mask

    # Without this check an unsupported layout fails later with an obscure
    # boolean-index shape mismatch inside the colouring loop.
    if labels.ndim != 2:
        raise ValueError(
            f"Cannot decode mask of shape {mask.shape}: expected (H, W), "
            f"(H, W, 1), (H, W, {num_classes}) or (B, H, W, C)."
        )

    color_mask = np.zeros(labels.shape + (3,), dtype=np.uint8)
    for class_idx, color in colors.items():
        color_mask[labels == class_idx] = color
    return color_mask
def predict_segmentation(index):
    """Run the U-Net on one test sample and return three images for display.

    Args:
        index: Position of the sample in ``X_test``. The dropdown supplies it
            as a string such as ``"0"``, so it is converted with ``int()``.

    Returns:
        Tuple ``(image, gt_mask, pred_mask)``: the input image rescaled to
        ``uint8`` 0-255, and the RGB-coloured ground-truth and predicted
        masks produced by ``decode_segmentation``.

    Raises:
        gr.Error: If the index is invalid or inference fails. Gradio shows
            this message in the UI instead of silently blanking the outputs.
    """
    try:
        idx = int(index)
        # Negative indices would silently wrap to the end of X_test.
        if not 0 <= idx < len(X_test):
            raise IndexError(
                f"index {idx} is out of range for {len(X_test)} test images"
            )
        sample = X_test[idx]
        # NOTE(review): assumes X_test holds floats in [0, 1]. Clipping stops
        # slightly out-of-range values from wrapping around in the uint8 cast.
        image = np.clip(sample * 255.0, 0, 255).astype(np.uint8)
        # Model prediction; verbose=0 suppresses the Keras progress bar that
        # would otherwise be printed on every request.
        pred = model.predict(np.expand_dims(sample, axis=0), verbose=0)
        pred_mask = decode_segmentation(pred)
        # Ground truth
        gt_mask = decode_segmentation(y_test[idx])
    except Exception as e:
        # Top-level UI boundary: keep the server log and surface the failure
        # to the user rather than returning three empty images.
        print("Error during prediction:", e)
        raise gr.Error(f"Prediction failed: {e}") from e
    return image, gt_mask, pred_mask
# -------------------------------
# Gradio Interface
# -------------------------------
title = "🛰️ Satellite Segmentation (U-Net)"
description = """
Select one of the **test images** from the dataset to visualize:
- The original input image
- The ground truth mask
- The predicted segmentation mask from the trained **U-Net**
"""
# Only the first 50 test samples are offered, to keep the dropdown short.
indices = list(range(min(50, len(X_test))))
index_choices = [str(i) for i in indices]

interface_kwargs = dict(
    fn=predict_segmentation,
    inputs=gr.Dropdown(
        choices=index_choices,
        label="Select Test Image Index",
        # "0" is not a valid choice when X_test is empty.
        value=index_choices[0] if index_choices else None,
        interactive=True,
    ),
    outputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Ground Truth Mask"),
        gr.Image(label="Predicted Mask"),
    ],
    title=title,
    description=description,
    # Built-in theme: same look as the "gradio/soft" Hub theme, but with no
    # network download at startup.
    theme=gr.themes.Soft(),
)
try:
    # Gradio 5 renamed `allow_flagging` to `flagging_mode`.
    interface = gr.Interface(**interface_kwargs, flagging_mode="never")
except TypeError:
    # Older Gradio releases only accept the original keyword. An unknown
    # keyword is rejected at call time, before any component is consumed.
    interface = gr.Interface(**interface_kwargs, allow_flagging="never")
# -------------------------------
# Launch
# -------------------------------
if __name__ == "__main__":
    # Listen on every network interface (not only localhost), port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    interface.launch(**launch_options)