statmlben commited on
Commit
573b083
·
verified ·
1 Parent(s): 5094272

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +101 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+ from rankseg import RankSEG
7
+ import numpy as np
8
+
9
+ # 1. Load Model (Cache it to avoid reloading)
10
+ # We use a lightweight DeepLabV3+ (MobileNet) for speed in the demo,
11
+ # or ResNet50 if we want better quality. Let's use ResNet50.
12
+ def load_model():
13
+ try:
14
+ weights = models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
15
+ model = models.segmentation.deeplabv3_resnet50(weights=weights)
16
+ except:
17
+ model = models.segmentation.deeplabv3_resnet50(pretrained=True)
18
+ model.eval()
19
+ return model
20
+
21
+ model = load_model()
22
+
23
+ # 2. Define Transformations
24
+ preprocess = transforms.Compose([
25
+ transforms.Resize(520),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
28
+ ])
29
+
30
+ # Color palette for visualization (PASCAL VOC style)
31
+ def get_palette():
32
+ return [
33
+ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
34
+ (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
35
+ (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
36
+ (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
37
+ (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
38
+ (0, 64, 128)
39
+ ]
40
+
41
+ def colorize_mask(mask):
42
+ # mask: (H, W) numpy array
43
+ palette = get_palette()
44
+ h, w = mask.shape
45
+ color_mask = np.zeros((h, w, 3), dtype=np.uint8)
46
+ for label, color in enumerate(palette):
47
+ color_mask[mask == label] = color
48
+ return color_mask
49
+
50
+ # 3. Inference Function
51
+ def predict(image):
52
+ if image is None:
53
+ return None, None
54
+
55
+ # Preprocess
56
+ input_tensor = preprocess(image).unsqueeze(0)
57
+
58
+ with torch.no_grad():
59
+ output = model(input_tensor)['out']
60
+ probs = F.softmax(output, dim=1) # (1, 21, H, W)
61
+
62
+ # --- METHOD 1: ARGMAX ---
63
+ argmax_pred = torch.argmax(probs, dim=1).squeeze().numpy()
64
+ argmax_vis = colorize_mask(argmax_pred)
65
+
66
+ # --- METHOD 2: RANKSEG ---
67
+ # Optimize for Dice
68
+ rankseg = RankSEG(metric='dice', solver='RMA')
69
+ rankseg_pred_tensor = rankseg.predict(probs)
70
+ rankseg_pred = rankseg_pred_tensor.squeeze().numpy()
71
+ rankseg_vis = colorize_mask(rankseg_pred)
72
+
73
+ return argmax_vis, rankseg_vis
74
+
75
+ # 4. Gradio Interface
76
+ title = "🧩 RankSEG: Optimize Segmentation Metrics without Retraining"
77
+ description = """
78
+ **RankSEG** (NeurIPS 2025) is a plug-and-play module that improves segmentation results by directly optimizing for Dice/IoU metrics during inference.
79
+ Upload an image to see how RankSEG refines the mask compared to standard Argmax.
80
+ """
81
+
82
+ # examples = [
83
+ # ["example1.jpg"], # You will need to upload example images to the HF Space
84
+ # ["example2.jpg"]
85
+ # ]
86
+
87
+ demo = gr.Interface(
88
+ fn=predict,
89
+ inputs=gr.Image(type="pil", label="Input Image"),
90
+ outputs=[
91
+ gr.Image(label="Standard Argmax Prediction"),
92
+ gr.Image(label="RankSEG Optimized (Dice)")
93
+ ],
94
+ title=title,
95
+ description=description,
96
+ # examples=examples, # Uncomment if you have examples
97
+ cache_examples=False
98
+ )
99
+
100
+ if __name__ == "__main__":
101
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch
3
+ torchvision
4
+ pillow
5
+ numpy
6
+ rankseg