# rankseg / app.py — Gradio demo for RankSEG (Hugging Face Spaces upload).
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
from rankseg import RankSEG
import numpy as np
import functools
# 1. Model Loader with Caching
# Support multiple models for user selection; lru_cache ensures each
# architecture is downloaded/built at most once per process.
@functools.lru_cache(maxsize=4)
def load_model(model_name):
    """Return a pretrained torchvision segmentation model in eval mode.

    Parameters
    ----------
    model_name : str
        One of the dropdown labels in the builder table below.

    Returns
    -------
    torch.nn.Module
        The model, moved to CUDA when available, otherwise on CPU.

    Raises
    ------
    ValueError
        If ``model_name`` is not a recognised label.
    """
    print(f"Loading model: {model_name}...")
    # Map dropdown label -> (constructor, name of its torchvision Weights enum).
    builders = {
        "DeepLabV3+ (ResNet50)": (models.segmentation.deeplabv3_resnet50,
                                  "DeepLabV3_ResNet50_Weights"),
        "DeepLabV3+ (ResNet101)": (models.segmentation.deeplabv3_resnet101,
                                   "DeepLabV3_ResNet101_Weights"),
        "DeepLabV3+ (MobileNetV3)": (models.segmentation.deeplabv3_mobilenet_v3_large,
                                     "DeepLabV3_MobileNet_V3_Large_Weights"),
        "LRASPP (MobileNetV3)": (models.segmentation.lraspp_mobilenet_v3_large,
                                 "LRASPP_MobileNet_V3_Large_Weights"),
    }
    try:
        builder, weights_enum_name = builders[model_name]
    except KeyError:
        # Previously an unknown name fell through every elif and crashed with
        # an opaque UnboundLocalError; fail loudly and helpfully instead.
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(builders)}"
        ) from None
    try:
        # torchvision >= 0.13: explicit Weights enums.
        weights = getattr(models.segmentation, weights_enum_name).DEFAULT
        model = builder(weights=weights)
    except AttributeError:
        # Older torchvision without the Weights API: fall back to the
        # deprecated `pretrained` flag. (Was a bare `except:`, which also
        # swallowed download errors and KeyboardInterrupt.)
        model = builder(pretrained=True)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model
# Initialize default model
# This only warms the lru_cache at import time so the first request is fast:
# predict() re-resolves its model via load_model(model_selection), so the
# `current_model` binding itself is not read anywhere else in this file.
current_model = load_model("DeepLabV3+ (ResNet50)")
# 2. Define Transformations
# Standard preprocessing for the torchvision segmentation models.
preprocess = transforms.Compose([
    transforms.Resize(520),  # int arg: shorter side -> 520 px, aspect ratio kept
    transforms.ToTensor(),   # PIL image -> float tensor in [0, 1], shape (C, H, W)
    # ImageNet channel statistics (mean/std the backbones were trained with).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Color palette for visualization (PASCAL VOC style)
def get_palette():
    """Return the 21-entry PASCAL VOC colour table (label index -> RGB tuple).

    Colours are produced by the standard VOC bit-interleaving recipe, which
    reproduces the familiar hard-coded palette exactly; index 0 (background)
    is black.
    """
    table = []
    for label in range(21):
        red = green = blue = 0
        bits = label
        for shift in range(8):
            red |= ((bits >> 0) & 1) << (7 - shift)
            green |= ((bits >> 1) & 1) << (7 - shift)
            blue |= ((bits >> 2) & 1) << (7 - shift)
            bits >>= 3
        table.append((red, green, blue))
    return table
def colorize_mask(mask, palette=None):
    """Map a 2-D integer label mask to an (H, W, 3) uint8 RGB image.

    Parameters
    ----------
    mask : np.ndarray
        (H, W) array of integer class labels.
    palette : sequence of (r, g, b) tuples, optional
        Colour per label index. Defaults to the PASCAL VOC palette from
        ``get_palette()``. Labels with no palette entry render black
        (matching the previous per-class loop); negative labels are
        clamped to index 0.

    Returns
    -------
    np.ndarray
        (H, W, 3) uint8 colour image.
    """
    if palette is None:
        palette = get_palette()
    colors = np.asarray(palette, dtype=np.uint8)
    # Build a lookup table covering every label that actually occurs, so
    # out-of-range labels map to black instead of raising on fancy indexing.
    max_label = int(mask.max()) if mask.size else 0
    lut = np.zeros((max(len(colors), max_label + 1), 3), dtype=np.uint8)
    lut[: len(colors)] = colors
    # One vectorized gather replaces the old O(num_classes * H * W)
    # boolean-mask loop.
    return lut[np.clip(mask, 0, lut.shape[0] - 1)]
# 3. Inference Function
def predict(image, model_selection):
    """Segment `image` with the selected model; return two colourised masks.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Input photo from the Gradio image widget.
    model_selection : str
        Dropdown label passed straight to load_model().

    Returns
    -------
    tuple
        (argmax_vis, rankseg_vis) as (H, W, 3) uint8 RGB arrays, or
        (None, None) when no image was provided.
    """
    if image is None:
        return None, None
    # Reload model if changed — load_model is lru_cached, so switching back
    # to a previously used model is free.
    model = load_model(model_selection)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Preprocess: resize/normalize, add batch dim -> (1, 3, H, W).
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)['out']
    probs = F.softmax(output, dim=1)  # (1, 21, H, W)
    # --- METHOD 1: ARGMAX ---
    # Baseline decision rule: per-pixel most-probable class.
    argmax_pred = torch.argmax(probs, dim=1).squeeze().cpu().numpy()
    argmax_vis = colorize_mask(argmax_pred)
    # --- METHOD 2: RANKSEG ---
    # Optimize for Dice via the RMA solver.
    # NOTE(review): assumes RankSEG.predict returns a (1, H, W) integer label
    # map compatible with colorize_mask — confirm against the rankseg docs.
    rankseg = RankSEG(metric='dice', solver='RMA')
    rankseg_pred_tensor = rankseg.predict(probs)
    rankseg_pred = rankseg_pred_tensor.squeeze().cpu().numpy()
    rankseg_vis = colorize_mask(rankseg_pred)
    return argmax_vis, rankseg_vis
# 4. Gradio Interface
# UI copy rendered by gr.Interface: `description` is HTML shown above the
# widgets, `article` is Markdown shown below. These are runtime strings and
# are displayed verbatim.
title = "🧩 RankSEG: Optimize Segmentation Metrics without Retraining"
description = """
<div style="text-align: center;">
<p>
<b>RankSEG</b> is a state-of-the-art <b>test-time optimization</b> framework for segmentation.
<br>
It replaces the standard <code>argmax</code> with a rigorous ranking-based solver to directly maximize <b>Dice</b> or <b>IoU</b> metrics during inference.
<br>
βœ… <b>Plug-and-Play</b>: Works with ANY pre-trained model (DeepLab, UNet, SegFormer, etc.)
<br>
βœ… <b>No Retraining</b>: Zero-shot performance boost
<br>
βœ… <b>Versatile</b>: Supports Binary, Multi-class, and Multi-label tasks
</p>
<div style="display: flex; justify-content: center; gap: 10px; margin-top: 10px;">
<a href="https://github.com/rankseg/rankseg" target="_blank">[GitHub]</a>
<a href="https://rankseg.readthedocs.io/" target="_blank">[Documentation]</a>
</div>
</div>
"""
article = """
### How it works
1. **Input**: The pre-trained model outputs probability maps.
2. **Optimization**: Instead of standard `argmax`, RankSEG uses a ranking-based solver (RMA) to select pixels that maximize the global Dice/IoU score.
3. **Result**: You typically get cleaner boundaries and better object recovery, consistent with the target metric.
"""
# Example images
# Each entry is [image_path, model_dropdown_value]; the files must be shipped
# alongside this script for the examples gallery to work.
examples = [
    ["demo1.jpg", "DeepLabV3+ (ResNet50)"],
    ["demo2.png", "LRASPP (MobileNetV3)"]
]
# Assemble the app: one image + one model selector in, two masks out.
demo = gr.Interface(
    fn=predict,  # (image, model_name) -> (argmax_vis, rankseg_vis)
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            # Choices must match the label table inside load_model exactly.
            choices=["DeepLabV3+ (ResNet50)", "DeepLabV3+ (ResNet101)", "DeepLabV3+ (MobileNetV3)", "LRASPP (MobileNetV3)"],
            value="DeepLabV3+ (ResNet50)",
            label="Select Pre-trained Model"
        )
    ],
    outputs=[
        gr.Image(label="Standard Argmax Prediction"),
        gr.Image(label="RankSEG Optimized (Dice)")
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=False  # run examples live instead of precomputing at startup
)

# Launch the local server only when run as a script (Spaces also executes
# this path).
if __name__ == "__main__":
    demo.launch()