Spaces:
Running
on
Zero
Running
on
Zero
add aligned+recursion
Browse files
app.py
CHANGED
|
@@ -419,9 +419,22 @@ def ncut_run(
|
|
| 419 |
metric="cosine" if i == 0 else recursion_metric,
|
| 420 |
)
|
| 421 |
logging_str += _logging_str
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
if recursion_metric == "cosine":
|
| 426 |
inp = F.normalize(inp, dim=-1)
|
| 427 |
return rgbs[0], rgbs[1], rgbs[2], logging_str
|
|
@@ -472,15 +485,12 @@ def ncut_run(
|
|
| 472 |
|
| 473 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 474 |
# dirty patch for the alignedcut paper
|
| 475 |
-
|
| 476 |
-
# for i_node in range(rgb.shape[1]):
|
| 477 |
-
# _rgb = rgb[:, i_node]
|
| 478 |
-
# galleries.append(to_pil_images(_rgb, target_size=56))
|
| 479 |
-
# return *galleries, logging_str
|
| 480 |
pil_images = []
|
| 481 |
for i_image in range(rgb.shape[0]):
|
| 482 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
| 483 |
pil_images.append(_im)
|
|
|
|
| 484 |
return pil_images, logging_str
|
| 485 |
|
| 486 |
|
|
@@ -516,26 +526,26 @@ def ncut_run(
|
|
| 516 |
|
| 517 |
def _ncut_run(*args, **kwargs):
|
| 518 |
n_ret = kwargs.pop("n_ret", 1)
|
| 519 |
-
try:
|
| 520 |
-
|
| 521 |
-
|
| 522 |
|
| 523 |
-
|
| 524 |
|
| 525 |
-
|
| 526 |
-
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
except Exception as e:
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
|
| 540 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 541 |
@spaces.GPU(duration=20)
|
|
@@ -1018,19 +1028,28 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
| 1018 |
return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
|
| 1019 |
|
| 1020 |
|
| 1021 |
-
def random_rotate_rgb_gallery(images):
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
|
| 1035 |
def sequence_rotate_rgb_gallery(images):
|
| 1036 |
if images is None or len(images) == 0:
|
|
@@ -1041,8 +1060,10 @@ def sequence_rotate_rgb_gallery(images):
|
|
| 1041 |
images = [np.array(image).astype(np.float32) for image in images]
|
| 1042 |
images = np.stack(images)
|
| 1043 |
images = torch.tensor(images) / 255
|
|
|
|
| 1044 |
rotation_matrix = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).float()
|
| 1045 |
images = images @ rotation_matrix
|
|
|
|
| 1046 |
images = to_pil_images(images, resize=False)
|
| 1047 |
return images
|
| 1048 |
|
|
@@ -1055,7 +1076,9 @@ def flip_rgb_gallery(images, axis=0):
|
|
| 1055 |
images = [np.array(image).astype(np.float32) for image in images]
|
| 1056 |
images = np.stack(images)
|
| 1057 |
images = torch.tensor(images) / 255
|
|
|
|
| 1058 |
images = 1 - images
|
|
|
|
| 1059 |
images = to_pil_images(images, resize=False)
|
| 1060 |
return images
|
| 1061 |
|
|
@@ -1074,7 +1097,7 @@ def make_output_images_section():
|
|
| 1074 |
add_output_images_buttons(output_gallery)
|
| 1075 |
return output_gallery
|
| 1076 |
|
| 1077 |
-
def make_parameters_section(is_lisa=False):
|
| 1078 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
| 1079 |
from ncut_pytorch.backbone import list_models, get_demo_model_names
|
| 1080 |
model_names = list_models()
|
|
@@ -1095,7 +1118,7 @@ def make_parameters_section(is_lisa=False):
|
|
| 1095 |
negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)
|
| 1096 |
node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
|
| 1097 |
else:
|
| 1098 |
-
model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True)
|
| 1099 |
model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
|
| 1100 |
model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
|
| 1101 |
layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
|
|
@@ -1292,6 +1315,7 @@ with demo:
|
|
| 1292 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
|
| 1293 |
num_images_slider.value = 100
|
| 1294 |
clear_images_button.visible = False
|
|
|
|
| 1295 |
|
| 1296 |
with gr.Column(scale=5, min_width=200):
|
| 1297 |
with gr.Accordion("➡️ Recursion config", open=True):
|
|
@@ -1311,14 +1335,6 @@ with demo:
|
|
| 1311 |
] = make_parameters_section()
|
| 1312 |
num_eig_slider.visible = False
|
| 1313 |
affinity_focal_gamma_slider.visible = False
|
| 1314 |
-
# logging text box
|
| 1315 |
-
with gr.Row():
|
| 1316 |
-
with gr.Column(scale=5, min_width=200):
|
| 1317 |
-
gr.Markdown(' ')
|
| 1318 |
-
with gr.Column(scale=5, min_width=200):
|
| 1319 |
-
gr.Markdown(' ')
|
| 1320 |
-
with gr.Column(scale=5, min_width=200):
|
| 1321 |
-
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 1322 |
true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
|
| 1323 |
true_placeholder.visible = False
|
| 1324 |
false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
|
|
@@ -1326,13 +1342,12 @@ with demo:
|
|
| 1326 |
number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
|
| 1327 |
number_placeholder.visible = False
|
| 1328 |
clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
|
| 1329 |
-
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 1330 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1331 |
|
| 1332 |
submit_button.click(
|
| 1333 |
partial(run_fn, n_ret=3),
|
| 1334 |
inputs=[
|
| 1335 |
-
input_gallery, model_dropdown, layer_slider,
|
| 1336 |
positive_prompt, negative_prompt,
|
| 1337 |
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 1338 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
|
@@ -1457,7 +1472,6 @@ with demo:
|
|
| 1457 |
with gr.Row():
|
| 1458 |
with gr.Column(scale=5, min_width=200):
|
| 1459 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
| 1460 |
-
|
| 1461 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
|
| 1462 |
num_images_slider.value = 100
|
| 1463 |
|
|
@@ -1476,7 +1490,7 @@ with demo:
|
|
| 1476 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1477 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 1478 |
sampling_method_dropdown, positive_prompt, negative_prompt
|
| 1479 |
-
] = make_parameters_section()
|
| 1480 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
| 1481 |
model_dropdown.visible = False
|
| 1482 |
layer_slider.visible = False
|
|
@@ -1505,7 +1519,7 @@ with demo:
|
|
| 1505 |
outputs=[output_gallery, logging_text],
|
| 1506 |
)
|
| 1507 |
|
| 1508 |
-
with gr.Tab('Model Aligned (+
|
| 1509 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 1510 |
gr.Markdown('---')
|
| 1511 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
@@ -1513,29 +1527,58 @@ with demo:
|
|
| 1513 |
gr.Markdown('')
|
| 1514 |
gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
|
| 1515 |
gr.Markdown('---')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1516 |
with gr.Row():
|
| 1517 |
with gr.Column(scale=5, min_width=200):
|
| 1518 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
| 1519 |
-
|
| 1520 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
|
| 1521 |
num_images_slider.value = 100
|
| 1522 |
|
| 1523 |
|
| 1524 |
with gr.Column(scale=5, min_width=200):
|
| 1525 |
-
|
| 1526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1527 |
gr.Markdown('---')
|
| 1528 |
gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
|
| 1529 |
gr.Markdown('Layer type: attention output (attn), without sum of residual')
|
| 1530 |
-
gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT')
|
| 1531 |
-
gr.Markdown('---')
|
| 1532 |
[
|
| 1533 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 1534 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
| 1535 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1536 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 1537 |
sampling_method_dropdown, positive_prompt, negative_prompt
|
| 1538 |
-
] = make_parameters_section()
|
|
|
|
|
|
|
| 1539 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
| 1540 |
model_dropdown.visible = False
|
| 1541 |
layer_slider.visible = False
|
|
@@ -1545,23 +1588,30 @@ with demo:
|
|
| 1545 |
# logging text box
|
| 1546 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 1547 |
|
| 1548 |
-
clear_images_button.click(lambda x: ([],
|
| 1549 |
|
| 1550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1551 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1552 |
|
| 1553 |
submit_button.click(
|
| 1554 |
-
run_fn,
|
| 1555 |
inputs=[
|
| 1556 |
-
input_gallery, model_dropdown, layer_slider,
|
| 1557 |
positive_prompt, negative_prompt,
|
| 1558 |
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 1559 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
| 1560 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1561 |
-
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
|
|
|
|
|
|
|
|
|
|
| 1562 |
],
|
| 1563 |
-
|
| 1564 |
-
outputs=[output_gallery, logging_text],
|
| 1565 |
)
|
| 1566 |
|
| 1567 |
|
|
|
|
| 419 |
metric="cosine" if i == 0 else recursion_metric,
|
| 420 |
)
|
| 421 |
logging_str += _logging_str
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if "AlignedThreeModelAttnNodes" == model_name:
|
| 425 |
+
# dirty patch for the alignedcut paper
|
| 426 |
+
start = time.time()
|
| 427 |
+
pil_images = []
|
| 428 |
+
for i_image in range(rgb.shape[0]):
|
| 429 |
+
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
| 430 |
+
pil_images.append(_im)
|
| 431 |
+
rgbs.append(pil_images)
|
| 432 |
+
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 433 |
+
else:
|
| 434 |
+
rgb = dont_use_too_much_green(rgb)
|
| 435 |
+
rgbs.append(to_pil_images(rgb))
|
| 436 |
+
|
| 437 |
+
inp = eigvecs.reshape(*features.shape[:-1], -1)
|
| 438 |
if recursion_metric == "cosine":
|
| 439 |
inp = F.normalize(inp, dim=-1)
|
| 440 |
return rgbs[0], rgbs[1], rgbs[2], logging_str
|
|
|
|
| 485 |
|
| 486 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 487 |
# dirty patch for the alignedcut paper
|
| 488 |
+
start = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
pil_images = []
|
| 490 |
for i_image in range(rgb.shape[0]):
|
| 491 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
| 492 |
pil_images.append(_im)
|
| 493 |
+
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 494 |
return pil_images, logging_str
|
| 495 |
|
| 496 |
|
|
|
|
| 526 |
|
| 527 |
def _ncut_run(*args, **kwargs):
    """Call ``ncut_run`` and trim its result to what the UI callback needs.

    The keyword ``n_ret`` (default 1) is removed from ``kwargs`` before the
    call; every other positional/keyword argument is forwarded unchanged.

    Returns:
        A list holding the first ``n_ret`` items of the ``ncut_run`` result
        followed by its last item (presumably the logging string -- confirm
        against ``ncut_run``'s return statements).
    """
    n_ret = kwargs.pop("n_ret", 1)

    # NOTE: an earlier version wrapped this call in try/except, emptied the
    # CUDA cache before and after the run, and on failure returned ``None``
    # placeholders plus an "Error: ..." message. That wrapper is disabled, so
    # any exception raised by ``ncut_run`` now propagates to the caller.
    outputs = ncut_run(*args, **kwargs)

    trimmed = list(outputs)[:n_ret]
    trimmed.append(outputs[-1])
    return trimmed
|
| 549 |
|
| 550 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 551 |
@spaces.GPU(duration=20)
|
|
|
|
| 1028 |
return dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
|
| 1029 |
|
| 1030 |
|
| 1031 |
+
# def random_rotate_rgb_gallery(images):
|
| 1032 |
+
# if images is None or len(images) == 0:
|
| 1033 |
+
# gr.Warning("No images selected.")
|
| 1034 |
+
# return []
|
| 1035 |
+
# # read webp images
|
| 1036 |
+
# images = [Image.open(image[0]).convert("RGB") for image in images]
|
| 1037 |
+
# images = [np.array(image).astype(np.float32) for image in images]
|
| 1038 |
+
# images = np.stack(images)
|
| 1039 |
+
# images = torch.tensor(images) / 255
|
| 1040 |
+
# position = np.random.choice([1, 2, 4, 5, 6])
|
| 1041 |
+
# images = rotate_rgb_cube(images, position)
|
| 1042 |
+
# images = to_pil_images(images, resize=False)
|
| 1043 |
+
# return images
|
| 1044 |
+
|
| 1045 |
+
def protect_original_image_in_plot(
    original_image,
    rotated_images,
    *,
    plot_h=332,
    plot_w=1542,
    protection_w=190,
):
    """Restore the left strip of recolored grid plots from the originals.

    When the images have exactly the expected plot size, the first
    ``protection_w`` columns of ``rotated_images`` are overwritten with the
    same columns of ``original_image``. Images of any other size are returned
    untouched.

    Args:
        original_image: Batch of images before the color transform. Axis 1 is
            read as height and axis 2 as width.
        rotated_images: Batch of color-transformed images with the same
            layout. It is modified in place when the size check passes.
        plot_h: Height an image must have to be treated as a grid plot.
        plot_w: Width an image must have to be treated as a grid plot.
        protection_w: Number of leading columns (along axis 2) to copy back
            from ``original_image``.

    Returns:
        ``rotated_images`` (the same object that was passed in).
    """
    image_h, image_w = original_image.shape[1], original_image.shape[2]

    # Only images with the exact plot dimensions are patched.
    # NOTE(review): the defaults presumably match the output size of
    # plot_one_image_36_grid, with the source photo in the left strip -- confirm.
    if (image_h, image_w) != (plot_h, plot_w):
        return rotated_images

    # Copy the protected columns back so the color transform does not
    # affect that part of the plot.
    rotated_images[:, :, :protection_w] = original_image[:, :, :protection_w]
    return rotated_images
|
| 1053 |
|
| 1054 |
def sequence_rotate_rgb_gallery(images):
|
| 1055 |
if images is None or len(images) == 0:
|
|
|
|
| 1060 |
images = [np.array(image).astype(np.float32) for image in images]
|
| 1061 |
images = np.stack(images)
|
| 1062 |
images = torch.tensor(images) / 255
|
| 1063 |
+
original_images = images.clone()
|
| 1064 |
rotation_matrix = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).float()
|
| 1065 |
images = images @ rotation_matrix
|
| 1066 |
+
images = protect_original_image_in_plot(original_images, images)
|
| 1067 |
images = to_pil_images(images, resize=False)
|
| 1068 |
return images
|
| 1069 |
|
|
|
|
| 1076 |
images = [np.array(image).astype(np.float32) for image in images]
|
| 1077 |
images = np.stack(images)
|
| 1078 |
images = torch.tensor(images) / 255
|
| 1079 |
+
original_images = images.clone()
|
| 1080 |
images = 1 - images
|
| 1081 |
+
images = protect_original_image_in_plot(original_images, images)
|
| 1082 |
images = to_pil_images(images, resize=False)
|
| 1083 |
return images
|
| 1084 |
|
|
|
|
| 1097 |
add_output_images_buttons(output_gallery)
|
| 1098 |
return output_gallery
|
| 1099 |
|
| 1100 |
+
def make_parameters_section(is_lisa=False, model_ratio=True):
|
| 1101 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
| 1102 |
from ncut_pytorch.backbone import list_models, get_demo_model_names
|
| 1103 |
model_names = list_models()
|
|
|
|
| 1118 |
negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)
|
| 1119 |
node_type_dropdown = gr.Dropdown(layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
|
| 1120 |
else:
|
| 1121 |
+
model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
|
| 1122 |
model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
|
| 1123 |
model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
|
| 1124 |
layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
|
|
|
|
| 1315 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
|
| 1316 |
num_images_slider.value = 100
|
| 1317 |
clear_images_button.visible = False
|
| 1318 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 1319 |
|
| 1320 |
with gr.Column(scale=5, min_width=200):
|
| 1321 |
with gr.Accordion("➡️ Recursion config", open=True):
|
|
|
|
| 1335 |
] = make_parameters_section()
|
| 1336 |
num_eig_slider.visible = False
|
| 1337 |
affinity_focal_gamma_slider.visible = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1338 |
true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
|
| 1339 |
true_placeholder.visible = False
|
| 1340 |
false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
|
|
|
|
| 1342 |
number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
|
| 1343 |
number_placeholder.visible = False
|
| 1344 |
clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
|
|
|
|
| 1345 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1346 |
|
| 1347 |
submit_button.click(
|
| 1348 |
partial(run_fn, n_ret=3),
|
| 1349 |
inputs=[
|
| 1350 |
+
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
| 1351 |
positive_prompt, negative_prompt,
|
| 1352 |
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 1353 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
|
|
|
| 1472 |
with gr.Row():
|
| 1473 |
with gr.Column(scale=5, min_width=200):
|
| 1474 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
|
|
|
| 1475 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
|
| 1476 |
num_images_slider.value = 100
|
| 1477 |
|
|
|
|
| 1490 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1491 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 1492 |
sampling_method_dropdown, positive_prompt, negative_prompt
|
| 1493 |
+
] = make_parameters_section(model_ratio=False)
|
| 1494 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
| 1495 |
model_dropdown.visible = False
|
| 1496 |
layer_slider.visible = False
|
|
|
|
| 1519 |
outputs=[output_gallery, logging_text],
|
| 1520 |
)
|
| 1521 |
|
| 1522 |
+
with gr.Tab('Model Aligned (+Rrecursion)'):
|
| 1523 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 1524 |
gr.Markdown('---')
|
| 1525 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
|
|
| 1527 |
gr.Markdown('')
|
| 1528 |
gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn")
|
| 1529 |
gr.Markdown('---')
|
| 1530 |
+
|
| 1531 |
+
# with gr.Row():
|
| 1532 |
+
# with gr.Column(scale=5, min_width=200):
|
| 1533 |
+
# gr.Markdown('### Output (Recursion #1)')
|
| 1534 |
+
# l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1535 |
+
# add_output_images_buttons(l1_gallery)
|
| 1536 |
+
# with gr.Column(scale=5, min_width=200):
|
| 1537 |
+
# gr.Markdown('### Output (Recursion #2)')
|
| 1538 |
+
# l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1539 |
+
# add_output_images_buttons(l2_gallery)
|
| 1540 |
+
# with gr.Column(scale=5, min_width=200):
|
| 1541 |
+
# gr.Markdown('### Output (Recursion #3)')
|
| 1542 |
+
# l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1543 |
+
# add_output_images_buttons(l3_gallery)
|
| 1544 |
+
gr.Markdown('### Output (Recursion #1)')
|
| 1545 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=False, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
|
| 1546 |
+
add_output_images_buttons(l1_gallery)
|
| 1547 |
+
gr.Markdown('### Output (Recursion #2)')
|
| 1548 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=False, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
|
| 1549 |
+
add_output_images_buttons(l2_gallery)
|
| 1550 |
+
gr.Markdown('### Output (Recursion #3)')
|
| 1551 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=False, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False, preview=True)
|
| 1552 |
+
add_output_images_buttons(l3_gallery)
|
| 1553 |
+
|
| 1554 |
with gr.Row():
|
| 1555 |
with gr.Column(scale=5, min_width=200):
|
| 1556 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
|
|
|
| 1557 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True, is_random=True)
|
| 1558 |
num_images_slider.value = 100
|
| 1559 |
|
| 1560 |
|
| 1561 |
with gr.Column(scale=5, min_width=200):
|
| 1562 |
+
with gr.Accordion("➡️ Recursion config", open=True):
|
| 1563 |
+
l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
|
| 1564 |
+
l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
|
| 1565 |
+
l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
|
| 1566 |
+
metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
|
| 1567 |
+
l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma")
|
| 1568 |
+
l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma")
|
| 1569 |
+
l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
|
| 1570 |
gr.Markdown('---')
|
| 1571 |
gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)')
|
| 1572 |
gr.Markdown('Layer type: attention output (attn), without sum of residual')
|
|
|
|
|
|
|
| 1573 |
[
|
| 1574 |
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 1575 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
| 1576 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1577 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 1578 |
sampling_method_dropdown, positive_prompt, negative_prompt
|
| 1579 |
+
] = make_parameters_section(model_ratio=False)
|
| 1580 |
+
num_eig_slider.visible = False
|
| 1581 |
+
affinity_focal_gamma_slider.visible = False
|
| 1582 |
model_dropdown.value = "AlignedThreeModelAttnNodes"
|
| 1583 |
model_dropdown.visible = False
|
| 1584 |
layer_slider.visible = False
|
|
|
|
| 1588 |
# logging text box
|
| 1589 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 1590 |
|
| 1591 |
+
clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
|
| 1592 |
|
| 1593 |
+
true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
|
| 1594 |
+
true_placeholder.visible = False
|
| 1595 |
+
false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
|
| 1596 |
+
false_placeholder.visible = False
|
| 1597 |
+
number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
|
| 1598 |
+
number_placeholder.visible = False
|
| 1599 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1600 |
|
| 1601 |
submit_button.click(
|
| 1602 |
+
partial(run_fn, n_ret=3),
|
| 1603 |
inputs=[
|
| 1604 |
+
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
| 1605 |
positive_prompt, negative_prompt,
|
| 1606 |
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 1607 |
affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
|
| 1608 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1609 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown,
|
| 1610 |
+
false_placeholder, number_placeholder, true_placeholder,
|
| 1611 |
+
l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
|
| 1612 |
+
l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider
|
| 1613 |
],
|
| 1614 |
+
outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
|
|
|
|
| 1615 |
)
|
| 1616 |
|
| 1617 |
|