put hyper-parameters
Browse files
app.py
CHANGED
|
@@ -22,39 +22,79 @@ maskclip = MaskClip().to(device)
|
|
| 22 |
dino = DINO().to(device)
|
| 23 |
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@spaces.GPU
|
| 26 |
-
def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None
|
|
|
|
| 27 |
img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
|
| 28 |
classnames = [c.strip() for c in classnames.split(",")]
|
| 29 |
num_classes = len(classnames)
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
preds = lposs(maskclip, dino, img_tensor, classnames)
|
| 32 |
if use_lposs_plus:
|
| 33 |
-
preds = lposs_plus(img_tensor, preds)
|
| 34 |
preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
|
| 35 |
preds = F.softmax(preds * 100, dim=1).cpu().numpy()
|
| 36 |
return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
gr.Image(label="Input Image"),
|
| 42 |
-
gr.Textbox(label="Class Names", info="Separate class names with commas"),
|
| 43 |
-
gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
|
| 44 |
-
],
|
| 45 |
-
outputs=[
|
| 46 |
-
gr.AnnotatedImage(label="Segmentation Results")
|
| 47 |
-
],
|
| 48 |
-
title="LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation",
|
| 49 |
-
article="""<div align='center' style='margin: 1em 0;'>
|
| 50 |
<a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
|
| 51 |
π arXiv
|
| 52 |
</a>
|
| 53 |
<a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
|
| 54 |
π» GitHub
|
| 55 |
</a>
|
| 56 |
-
</div>"""
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
demo.launch()
|
|
|
|
| 22 |
dino = DINO().to(device)
|
| 23 |
to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
|
| 24 |
|
| 25 |
+
# Default hyperparameter values
|
| 26 |
+
DEFAULT_SIGMA = 100
|
| 27 |
+
DEFAULT_ALPHA = 0.95
|
| 28 |
+
DEFAULT_K = 400
|
| 29 |
+
DEFAULT_WSIZE = 224
|
| 30 |
+
DEFAULT_GAMMA = 3.0
|
| 31 |
+
DEFAULT_TAU = 0.01
|
| 32 |
+
|
| 33 |
+
# Function to reset hyperparameters to default values
|
| 34 |
+
def reset_hyperparams():
|
| 35 |
+
return DEFAULT_WSIZE, DEFAULT_K, DEFAULT_GAMMA, DEFAULT_ALPHA, DEFAULT_SIGMA, DEFAULT_TAU
|
| 36 |
+
|
| 37 |
@spaces.GPU
|
| 38 |
+
def segment_image(img: PIL.Image.Image, classnames: str, use_lposs_plus: bool | None,
|
| 39 |
+
winodw_size:int, k:int, gamma:float, alpha:float, sigma: float, tau:float) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
|
| 40 |
img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
|
| 41 |
classnames = [c.strip() for c in classnames.split(",")]
|
| 42 |
num_classes = len(classnames)
|
| 43 |
+
|
| 44 |
+
winodw_size = (winodw_size, winodw_size)
|
| 45 |
+
stride = (winodw_size[0] // 2, winodw_size[1] // 2)
|
| 46 |
|
| 47 |
+
preds = lposs(maskclip, dino, img_tensor, classnames, window_size=winodw_size, window_stride=stride, sigma=1/sigma, lp_k_image=k, lp_gamma=gamma, lp_alpha=alpha)
|
| 48 |
if use_lposs_plus:
|
| 49 |
+
preds = lposs_plus(img_tensor, preds, tau=tau, alpha=alpha)
|
| 50 |
preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
|
| 51 |
preds = F.softmax(preds * 100, dim=1).cpu().numpy()
|
| 52 |
return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
|
| 53 |
|
| 54 |
+
with gr.Blocks() as demo:
|
| 55 |
+
gr.Markdown("# LPOSS: Label Propagation Over Patches and Pixels for Open-vocabulary Semantic Segmentation")
|
| 56 |
+
gr.Markdown("""<div align='center' style='margin: 1em 0;'>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
<a href='http://arxiv.org/abs/2503.19777' target='_blank' style='margin-right: 2em; text-decoration: none; font-weight: bold;'>
|
| 58 |
π arXiv
|
| 59 |
</a>
|
| 60 |
<a href='https://github.com/vladan-stojnic/LPOSS' target='_blank' style='text-decoration: none; font-weight: bold;'>
|
| 61 |
π» GitHub
|
| 62 |
</a>
|
| 63 |
+
</div>""")
|
| 64 |
+
gr.Markdown("Upload an image and specify the objects you want to segment by listing their names separated by commas.")
|
| 65 |
+
|
| 66 |
+
with gr.Row(variant="panel"):
|
| 67 |
+
with gr.Column(scale=1):
|
| 68 |
+
with gr.Row():
|
| 69 |
+
gr.Markdown("Hyper-parameters")
|
| 70 |
+
with gr.Row():
|
| 71 |
+
window_size = gr.Slider(minimum=112, maximum=448, value=DEFAULT_WSIZE, step=16, label="Window Size")
|
| 72 |
+
k = gr.Slider(minimum=50, maximum=800, value=DEFAULT_K, step=50, label="k")
|
| 73 |
+
gamma = gr.Slider(minimum=0.0, maximum=10.0, value=DEFAULT_GAMMA, step=0.5, label="Gamma")
|
| 74 |
+
alpha = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_ALPHA, step=0.05, label="Alpha")
|
| 75 |
+
sigma = gr.Slider(minimum=50, maximum=400, value=DEFAULT_SIGMA, step=10, label="Sigma")
|
| 76 |
+
tau = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TAU, step=0.01, label="Tau")
|
| 77 |
+
with gr.Row():
|
| 78 |
+
reset_btn = gr.Button("Reset to Default Values")
|
| 79 |
+
|
| 80 |
+
with gr.Row():
|
| 81 |
+
with gr.Column(scale=2):
|
| 82 |
+
input_image = gr.Image(label="Input Image")
|
| 83 |
+
class_names = gr.Textbox(label="Class Names", info="Separate class names with commas")
|
| 84 |
+
use_lposs_plus = gr.Checkbox(label="Use LPOSS+", info="Enable pixel-level refinement using LPOSS+")
|
| 85 |
+
|
| 86 |
+
with gr.Column(scale=3):
|
| 87 |
+
output_image = gr.AnnotatedImage(label="Segmentation Results")
|
| 88 |
+
|
| 89 |
+
with gr.Row():
|
| 90 |
+
segment_btn = gr.Button("Segment Image")
|
| 91 |
+
|
| 92 |
+
reset_btn.click(fn=reset_hyperparams, outputs=[window_size, k, gamma, alpha, sigma, tau])
|
| 93 |
+
|
| 94 |
+
segment_btn.click(
|
| 95 |
+
fn=segment_image,
|
| 96 |
+
inputs=[input_image, class_names, use_lposs_plus, window_size, k, gamma, alpha, sigma, tau],
|
| 97 |
+
outputs=[output_image]
|
| 98 |
+
)
|
| 99 |
|
| 100 |
demo.launch()
|