|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
from rankseg import RankSEG |
|
|
import numpy as np |
|
|
import functools |
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def load_model(model_name):
    """Load a pre-trained torchvision segmentation model by display name.

    Results are memoized with ``lru_cache`` so each architecture is built
    (and its weights downloaded) at most once per process.

    Args:
        model_name: One of the dropdown display names, e.g.
            "DeepLabV3+ (ResNet50)".

    Returns:
        The model in eval mode, moved to CUDA when available.

    Raises:
        ValueError: If ``model_name`` is not a recognized choice.
    """
    print(f"Loading model: {model_name}...")
    seg = models.segmentation
    # Display name -> (weights-enum attribute name, constructor).
    registry = {
        "DeepLabV3+ (ResNet50)": ("DeepLabV3_ResNet50_Weights", seg.deeplabv3_resnet50),
        "DeepLabV3+ (ResNet101)": ("DeepLabV3_ResNet101_Weights", seg.deeplabv3_resnet101),
        "DeepLabV3+ (MobileNetV3)": ("DeepLabV3_MobileNet_V3_Large_Weights", seg.deeplabv3_mobilenet_v3_large),
        "LRASPP (MobileNetV3)": ("LRASPP_MobileNet_V3_Large_Weights", seg.lraspp_mobilenet_v3_large),
    }
    try:
        weights_attr, builder = registry[model_name]
    except KeyError:
        # Previously an unknown name fell through every branch and crashed
        # later with UnboundLocalError; fail fast with a clear message.
        raise ValueError(f"Unknown model selection: {model_name!r}")

    try:
        # Modern torchvision (>= 0.13) exposes per-model weights enums.
        weights = getattr(seg, weights_attr).DEFAULT
        model = builder(weights=weights)
    except AttributeError:
        # Older torchvision has no weights enums: use the legacy API.
        # (Narrowed from a bare `except:` that also swallowed KeyboardInterrupt
        # and genuine download errors.)
        model = builder(pretrained=True)

    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    return model
|
|
|
|
|
|
|
|
# Pre-load the default model at import time so the first request is fast
# (this also warms the lru_cache inside load_model).
# NOTE(review): `current_model` itself is never read again in this file —
# predict() re-fetches via load_model(); kept only for the preload side effect.
current_model = load_model("DeepLabV3+ (ResNet50)")
|
|
|
|
|
|
|
|
# Standard ImageNet-style preprocessing for the torchvision segmentation models:
# resize the shorter side to 520 px, convert to a CHW float tensor in [0, 1],
# then normalize with the ImageNet channel means and standard deviations.
preprocess = transforms.Compose([
    transforms.Resize(520),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
|
|
|
def get_palette():
    """Return the 21-entry PASCAL VOC color palette as a list of RGB tuples.

    Index 0 is the black background; the remaining 20 entries are the VOC
    class colors, generated with the standard VOC bit-interleaving scheme.
    """
    palette = []
    for label in range(21):
        red = green = blue = 0
        bits = label
        for shift in range(8):
            red |= ((bits >> 0) & 1) << (7 - shift)
            green |= ((bits >> 1) & 1) << (7 - shift)
            blue |= ((bits >> 2) & 1) << (7 - shift)
            bits >>= 3
        palette.append((red, green, blue))
    return palette
|
|
|
|
|
def colorize_mask(mask):
    """Convert an (H, W) array of integer class labels to an (H, W, 3) RGB image.

    Labels with no palette entry (>= 21) are left black.
    """
    height, width = mask.shape
    colored = np.zeros((height, width, 3), dtype=np.uint8)
    for class_index, rgb in enumerate(get_palette()):
        colored[mask == class_index] = rgb
    return colored
|
|
|
|
|
|
|
|
def predict(image, model_selection):
    """Segment an image and decode it two ways: standard argmax and RankSEG.

    Args:
        image: PIL image from the Gradio widget, or None when cleared.
        model_selection: Display name of the pre-trained model to use.

    Returns:
        Tuple ``(argmax_vis, rankseg_vis)`` of (H, W, 3) uint8 color masks,
        or ``(None, None)`` when no image was provided.
    """
    if image is None:
        return None, None

    model = load_model(model_selection)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Force 3-channel RGB so the 3-channel Normalize in `preprocess` never
    # sees grayscale or RGBA input (e.g. PNGs with an alpha channel).
    input_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor)['out']
        probs = F.softmax(logits, dim=1)

    # Baseline decoding: per-pixel argmax over class probabilities.
    # squeeze(0) removes exactly the batch dim (safer than squeeze() if a
    # spatial dim ever happened to be 1).
    argmax_pred = torch.argmax(probs, dim=1).squeeze(0).cpu().numpy()
    argmax_vis = colorize_mask(argmax_pred)

    # RankSEG decoding: ranking-based solver targeting the Dice metric.
    rankseg = RankSEG(metric='dice', solver='RMA')
    rankseg_pred_tensor = rankseg.predict(probs)
    # NOTE(review): output shape of RankSEG.predict is not visible here;
    # plain squeeze() kept for the original's behavior — confirm it is (1, H, W).
    rankseg_pred = rankseg_pred_tensor.squeeze().cpu().numpy()
    rankseg_vis = colorize_mask(rankseg_pred)

    return argmax_vis, rankseg_vis
|
|
|
|
|
|
|
|
# User-facing copy for the Gradio UI. The emoji in the previous revision were
# mojibake ("π§©", lone "β" characters); restored to the intended "🧩" / "✅".
title = "🧩 RankSEG: Optimize Segmentation Metrics without Retraining"

description = """
<div style="text-align: center;">
<p>
<b>RankSEG</b> is a state-of-the-art <b>test-time optimization</b> framework for segmentation.
<br>
It replaces the standard <code>argmax</code> with a rigorous ranking-based solver to directly maximize <b>Dice</b> or <b>IoU</b> metrics during inference.
<br>
✅ <b>Plug-and-Play</b>: Works with ANY pre-trained model (DeepLab, UNet, SegFormer, etc.)
<br>
✅ <b>No Retraining</b>: Zero-shot performance boost
<br>
✅ <b>Versatile</b>: Supports Binary, Multi-class, and Multi-label tasks
</p>
<div style="display: flex; justify-content: center; gap: 10px; margin-top: 10px;">
<a href="https://github.com/rankseg/rankseg" target="_blank">[GitHub]</a>
<a href="https://rankseg.readthedocs.io/" target="_blank">[Documentation]</a>
</div>
</div>
"""

article = """
### How it works
1. **Input**: The pre-trained model outputs probability maps.
2. **Optimization**: Instead of standard `argmax`, RankSEG uses a ranking-based solver (RMA) to select pixels that maximize the global Dice/IoU score.
3. **Result**: You typically get cleaner boundaries and better object recovery, consistent with the target metric.
"""

# Example rows shown beneath the interface; the image files must exist next to
# this script for the examples to load.
examples = [
    ["demo1.jpg", "DeepLabV3+ (ResNet50)"],
    ["demo2.png", "LRASPP (MobileNetV3)"]
]
|
|
|
|
|
# Wire the prediction function into a Gradio interface: an image upload plus a
# model dropdown as inputs, and the two decoded masks side by side as outputs.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        # The choices must exactly match the names handled by load_model().
        gr.Dropdown(
            choices=["DeepLabV3+ (ResNet50)", "DeepLabV3+ (ResNet101)", "DeepLabV3+ (MobileNetV3)", "LRASPP (MobileNetV3)"],
            value="DeepLabV3+ (ResNet50)",
            label="Select Pre-trained Model"
        )
    ],
    outputs=[
        gr.Image(label="Standard Argmax Prediction"),
        gr.Image(label="RankSEG Optimized (Dice)")
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    # NOTE(review): presumably disabled so the example images are not run
    # through the models (and weights downloaded) at build time — confirm.
    cache_examples=False
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio web server when this file is run as a script.
    demo.launch()
|
|
|