Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
| import segmentation_models_pytorch as smp | |
| import numpy as np | |
# Number of segmentation classes the trained checkpoint predicts
# (label_colors.txt lists four of them).
NUM_CLASSES = 4

# Class index -> RGB color used when rendering a predicted mask.
# The list position is the class index.
COLOR_MAPPING = dict(
    enumerate(
        [
            [0, 0, 0],  # 0: background (black)
            [255, 0, 124],  # 1: oil
            [255, 204, 51],  # 2: others
            [51, 221, 255],  # 3: water
        ]
    )
)
def colorize_mask(mask, color_mapping=None):
    """
    Convert a 2D mask of class indices into an RGB color image.

    Args:
        mask (np.ndarray or array-like): 2D array (H x W) of integer class indices.
        color_mapping (dict, optional): Mapping of class index -> [R, G, B].
            Defaults to the module-level COLOR_MAPPING when omitted.

    Returns:
        np.ndarray: uint8 array of shape (H, W, 3). Pixels whose class index is
        not a key of the mapping stay black (0, 0, 0).

    Raises:
        ValueError: If ``mask`` is not two-dimensional.
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"colorize_mask expects a 2D mask, got shape {mask.shape}")
    color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for cls, color in color_mapping.items():
        color_mask[mask == cls] = color
    return color_mask
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading "module." removed from every key.

    Checkpoints saved from a model wrapped in torch.nn.DataParallel /
    DistributedDataParallel prefix parameter names with "module.", which an
    unwrapped model's load_state_dict rejects. The prefix is only stripped
    when every key has it; otherwise the dict is returned unchanged.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


# Download the trained weights from the Hugging Face Hub repository.
model_path = hf_hub_download(
    repo_id="TheArchitect416/oil-spill-segmentation-model", filename="model.pth"
)

# Build the network. The architecture must match the one used in training,
# otherwise load_state_dict (strict by default) raises on mismatched keys.
# encoder_weights=None: the strict load below overwrites every parameter, so
# fetching ImageNet encoder weights here would only add startup time and a
# network dependency that can fail.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,  # RGB input
    classes=NUM_CLASSES,
)

# NOTE(review): torch.load unpickles the downloaded file; on torch versions that
# support it, prefer torch.load(..., weights_only=True) for files fetched over
# the network. This assumes model.pth holds a bare state dict (not a pickled
# nn.Module or a {"state_dict": ...} wrapper) — confirm against training code.
# The loaded dict is passed inline so no second copy of the weights stays alive.
model.load_state_dict(_strip_module_prefix(torch.load(model_path, map_location="cpu")))
model.eval()
# Input pipeline applied to every image; it must mirror the preprocessing used
# at training time: fixed 256x256 resize, conversion to a tensor, then
# per-channel normalization with the ImageNet statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),  # per-channel mean (ImageNet)
            (0.229, 0.224, 0.225),  # per-channel std (ImageNet)
        ),
    ]
)
# Inference entry point used by the Gradio interface.
def predict(image):
    """
    Segment one image and return its colorized class mask.

    Args:
        image (PIL.Image.Image | None): Image from the Gradio input component.
            ``None`` (nothing uploaded) yields ``None``.

    Returns:
        np.ndarray | None: uint8 RGB mask (H x W x 3) at the resolution the
        model produces (expected 256x256 given ``preprocess``, not the original
        image size), or ``None`` when no image was given.
    """
    if image is None:
        # Nothing to segment; avoids crashing inside preprocess.
        return None
    # RGBA (PNG with alpha), grayscale or palette images would yield a channel
    # count that does not match the 3-channel Normalize / model input.
    image = image.convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add batch dim -> [1, 3, 256, 256]
    with torch.no_grad():
        logits = model(input_tensor)
    # Argmax over the class dimension gives one class index per pixel.
    pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    return colorize_mask(pred_mask)
# Web UI: a PIL image goes in, the colorized numpy mask from `predict` comes out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
)

print("Gradio version:", gr.__version__)

if __name__ == "__main__":
    # Turn on request queuing, then serve on all interfaces at port 7860.
    iface.queue()
    iface.launch(server_port=7860, server_name="0.0.0.0")