import functools

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from rankseg import RankSEG
from torchvision import models, transforms

# 1. Model Loader with Caching
# Support multiple models for user selection

# Display name -> (builder function name, weights enum name); both names are
# looked up on ``torchvision.models.segmentation`` at load time.
_SEGMENTATION_MODELS = {
    "DeepLabV3+ (ResNet50)": (
        "deeplabv3_resnet50",
        "DeepLabV3_ResNet50_Weights",
    ),
    "DeepLabV3+ (ResNet101)": (
        "deeplabv3_resnet101",
        "DeepLabV3_ResNet101_Weights",
    ),
    "DeepLabV3+ (MobileNetV3)": (
        "deeplabv3_mobilenet_v3_large",
        "DeepLabV3_MobileNet_V3_Large_Weights",
    ),
    "LRASPP (MobileNetV3)": (
        "lraspp_mobilenet_v3_large",
        "LRASPP_MobileNet_V3_Large_Weights",
    ),
}


@functools.lru_cache(maxsize=4)
def load_model(model_name):
    """Return the pretrained segmentation model registered under *model_name*.

    The model is put in eval mode and moved to the GPU when CUDA is
    available. Results are memoised per name by ``lru_cache``, so selecting
    the same model again does not rebuild it.

    Args:
        model_name: One of the keys of ``_SEGMENTATION_MODELS``.

    Returns:
        The torchvision segmentation model, ready for inference.

    Raises:
        ValueError: If *model_name* is not a known model.
    """
    try:
        builder_name, weights_name = _SEGMENTATION_MODELS[model_name]
    except KeyError:
        known = ", ".join(sorted(_SEGMENTATION_MODELS))
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of: {known}"
        ) from None

    print(f"Loading model: {model_name}...")
    builder = getattr(models.segmentation, builder_name)
    try:
        weights = getattr(models.segmentation, weights_name).DEFAULT
    except AttributeError:
        # No weights enum on this torchvision install: fall back to the
        # legacy ``pretrained=True`` flag.
        model = builder(pretrained=True)
    else:
        model = builder(weights=weights)

    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model


# Initialize default model
current_model = load_model("DeepLabV3+ (ResNet50)")

# 2.
Define Transformations preprocess = transforms.Compose([ transforms.Resize(520), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Color palette for visualization (PASCAL VOC style) def get_palette(): return [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128) ] def colorize_mask(mask): # mask: (H, W) numpy array palette = get_palette() h, w = mask.shape color_mask = np.zeros((h, w, 3), dtype=np.uint8) for label, color in enumerate(palette): color_mask[mask == label] = color return color_mask # 3. Inference Function def predict(image, model_selection): if image is None: return None, None # Reload model if changed (or use cached) # Note: In a real heavy-traffic app, we might want to keep models in memory dict model = load_model(model_selection) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # Preprocess input_tensor = preprocess(image).unsqueeze(0).to(device) with torch.no_grad(): output = model(input_tensor)['out'] probs = F.softmax(output, dim=1) # (1, 21, H, W) # --- METHOD 1: ARGMAX --- argmax_pred = torch.argmax(probs, dim=1).squeeze().cpu().numpy() argmax_vis = colorize_mask(argmax_pred) # --- METHOD 2: RANKSEG --- # Optimize for Dice rankseg = RankSEG(metric='dice', solver='RMA') rankseg_pred_tensor = rankseg.predict(probs) rankseg_pred = rankseg_pred_tensor.squeeze().cpu().numpy() rankseg_vis = colorize_mask(rankseg_pred) return argmax_vis, rankseg_vis # 4. Gradio Interface title = "🧩 RankSEG: Optimize Segmentation Metrics without Retraining" description = """
RankSEG is a state-of-the-art test-time optimization framework for segmentation.
It replaces the standard argmax with a rigorous ranking-based solver to directly maximize Dice or IoU metrics during inference.
✅ Plug-and-Play: Works with ANY pre-trained model (DeepLab, UNet, SegFormer, etc.)
✅ No Retraining: Zero-shot performance boost
✅ Versatile: Supports Binary, Multi-class, and Multi-label tasks