Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import segmentation_models_pytorch as smp | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from torchvision.utils import draw_segmentation_masks | |
# Model/training configuration. "model_architecture" names a
# segmentation_models_pytorch class and "model_config" holds its kwargs.
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": dict(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=7,
    ),
}
# RGB colour per class id: entry i colours the mask for class i when overlays are drawn.
colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]
# Checkpoint with the trained model weights.
cp_path = "CP_epoch20.pth"
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# load model
# Resolve the architecture class by name (e.g. "Unet" -> smp.Unet).
model_architecture = getattr(smp, config["model_architecture"])
# encoder_weights is overridden to None here: load_state_dict below is strict and
# replaces every parameter/buffer with the checkpoint, so fetching pretrained
# encoder weights at start-up would be wasted work (and fails without network).
# The training config itself is left unchanged.
model = model_architecture(**{**config["model_config"], "encoder_weights": None})
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
# Inference mode for layers such as batch norm / dropout.
model.eval()
# transforms
# Square resize to the resolution the network is run at (config["downsize_res"]).
downsize_t = transforms.Resize((config["downsize_res"],) * 2, antialias=True)
# PIL image -> tensor; the only preprocessing step before downsizing.
transform = transforms.Compose([transforms.ToTensor()])
def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a label-encoded mask into a boolean one-hot tensor.

    Args:
        mask: Integer class ids, either (H, W) or (N, H, W).
        num_classes: Number of classes, i.e. the size of the one-hot axis.

    Returns:
        A bool tensor with the class axis in front of the spatial axes:
        (num_classes, H, W) for a 2-D mask, (N, num_classes, H, W) for a 3-D one.
    """
    # one_hot appends the class axis last: (..., H, W, num_classes).
    encoded = F.one_hot(mask.long(), num_classes=num_classes).bool()
    if mask.ndim == 2:
        return encoded.permute(2, 0, 1)
    return encoded.permute(0, 3, 1, 2)
def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Generate the segmentation overlay for a satellite image.

    Args:
        image: (C, H, W) image tensor to draw on (see torchvision's
            ``draw_segmentation_masks`` for accepted dtypes).
        preds: Class-id map whose last two dims are (H, W); any leading dims
            must be singleton, e.g. the (1, H, W) map produced by ``segment``.
        alpha: Opacity of the drawn class colours, passed straight to
            ``draw_segmentation_masks``.

    Returns:
        ``image`` with one coloured mask per class drawn over it.
    """
    # Reshape to (H, W) instead of squeeze(): squeeze() would also drop a
    # spatial dim of size 1 (a 1-pixel-tall/wide image) and break the one-hot.
    label_map = preds.reshape(preds.shape[-2:])
    # One boolean mask per class, paired index-by-index with ``colors``. The
    # masks are moved to the image's device so drawing never mixes devices.
    masks = label_to_onehot(label_map, len(colors)).to(image.device)
    return draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)
def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D tensor from (H, W, C) to (C, H, W), returned as a view."""
    return image_tensor.permute(2, 0, 1)
def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D tensor from (C, H, W) to (H, W, C), returned as a view."""
    return image_tensor.permute(1, 2, 0)
def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Predict land-cover classes for an image and render them two ways.

    Args:
        satellite_image: Image from the Gradio input; assumed to be an
            (H, W, 3) uint8 array (TODO confirm for other upload modes).

    Returns:
        A ``(raw_segmentation, segmentation_overlay)`` pair of (H, W, C)
        arrays: class colours drawn on a black canvas, and class colours
        blended over the input image with alpha 0.2.
    """
    image_tensor = hwc_to_chw(torch.from_numpy(satellite_image))
    pil_image = transforms.functional.to_pil_image(image_tensor)

    # preprocess image: PIL -> batched tensor on the model device, then shrink
    # it to the resolution the network is run at
    X = transform(pil_image).unsqueeze(0).to(device)
    X_down = downsize_t(X)

    # forward pass; no_grad avoids building an autograd graph at inference
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1)

    # Resize the label map back to the original resolution. Class ids are
    # categorical, so NEAREST is required: bilinear/antialiased resizing
    # averages neighbouring ids and invents classes along region boundaries.
    preds = transforms.functional.resize(
        preds,
        list(X.shape[-2:]),
        interpolation=transforms.InterpolationMode.NEAREST,
    )
    # The drawing step indexes the CPU image tensor with masks built from
    # preds, so bring preds back to the CPU (no-op when device == "cpu").
    preds = preds.cpu()

    # get rgb formatted images
    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
    raw_segmentation = chw_to_hwc(
        get_overlay(torch.zeros_like(image_tensor), preds, 1)
    ).numpy()
    return raw_segmentation, segmentation_overlay
# gr.Image (not the deprecated gr.inputs.Image, removed in Gradio 4) for
# consistency with the output components below.
inputs = gr.Image(label="Input Image")
outputs = [gr.Image(label="Raw Segmentation"), gr.Image(label="Segmentation Overlay")]
images_dir = "sample_sat_images/"
# os.path.join avoids the "dir//file" double slash; sorted() keeps the example
# gallery order stable, since os.listdir order is arbitrary.
examples = [os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))]
title = "Satellite Images Landcover Classification"
description = (
    "Upload a satellite image from your computer or select one from"
    " the examples. The model will automatically segment the landcover"
    " types from a preselected set of possible types."
)
# Read the article text with a context manager so the file handle is closed.
with open("article.md", "r", encoding="utf-8") as article_file:
    article = article_file.read()
iface = gr.Interface(
    segment,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    cache_examples=True,
    article=article,
)
iface.launch()