Spaces:
Running
on
Zero
add advanced tab for recursive
Browse files
app.py
CHANGED
|
@@ -325,9 +325,9 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
|
|
| 325 |
mask = mask[mask_sort_idx[:3]]
|
| 326 |
sort_values.append(mask.mean().item())
|
| 327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
| 328 |
-
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:10]].cpu()
|
| 329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
| 330 |
-
top10_image_idx[idx.item()] = mask_sort_idx[:10]
|
| 331 |
# do the sorting
|
| 332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
| 333 |
fps_idx = fps_idx[_sort_idx]
|
|
@@ -427,6 +427,7 @@ def ncut_run(
|
|
| 427 |
alignedcut_eig_norm_plot=False,
|
| 428 |
**kwargs,
|
| 429 |
):
|
|
|
|
| 430 |
progress = gr.Progress()
|
| 431 |
progress(0.2, desc="Feature Extraction")
|
| 432 |
|
|
@@ -483,6 +484,7 @@ def ncut_run(
|
|
| 483 |
|
| 484 |
if recursion:
|
| 485 |
rgbs = []
|
|
|
|
| 486 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
| 487 |
inp = features
|
| 488 |
progress_start = 0.4
|
|
@@ -509,6 +511,7 @@ def ncut_run(
|
|
| 509 |
progess_start=progress_start,
|
| 510 |
)
|
| 511 |
logging_str += _logging_str
|
|
|
|
| 512 |
|
| 513 |
|
| 514 |
if "AlignedThreeModelAttnNodes" == model_name:
|
|
@@ -528,8 +531,42 @@ def ncut_run(
|
|
| 528 |
inp = eigvecs.reshape(*features.shape[:-1], -1)
|
| 529 |
if recursion_metric == "cosine":
|
| 530 |
inp = F.normalize(inp, dim=-1)
|
| 531 |
-
|
| 532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
if old_school_ncut: # individual images
|
| 534 |
logging_str += "Running NCut for each image independently\n"
|
| 535 |
rgb = []
|
|
@@ -643,7 +680,7 @@ def ncut_run(
|
|
| 643 |
norm_images = to_pil_images(norm_images)
|
| 644 |
logging_str += "Eigenvector Magnitude\n"
|
| 645 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 646 |
-
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
| 647 |
|
| 648 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
| 649 |
|
|
@@ -651,26 +688,26 @@ def ncut_run(
|
|
| 651 |
|
| 652 |
def _ncut_run(*args, **kwargs):
|
| 653 |
n_ret = kwargs.pop("n_ret", 1)
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
|
| 658 |
-
|
| 659 |
|
| 660 |
-
|
| 661 |
-
|
| 662 |
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
ret = ncut_run(*args, **kwargs)
|
| 672 |
-
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 673 |
-
return ret
|
| 674 |
|
| 675 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 676 |
@spaces.GPU(duration=30)
|
|
@@ -1407,7 +1444,7 @@ with demo:
|
|
| 1407 |
scroll_to_output=True,
|
| 1408 |
)
|
| 1409 |
|
| 1410 |
-
with gr.Tab('AlignedCut (
|
| 1411 |
|
| 1412 |
with gr.Row():
|
| 1413 |
with gr.Column(scale=5, min_width=200):
|
|
@@ -1582,6 +1619,78 @@ with demo:
|
|
| 1582 |
outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
|
| 1583 |
api_name="API_RecursiveCut"
|
| 1584 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1585 |
|
| 1586 |
|
| 1587 |
with gr.Tab('Video'):
|
|
@@ -1741,7 +1850,7 @@ with demo:
|
|
| 1741 |
outputs=[output_gallery, logging_text],
|
| 1742 |
)
|
| 1743 |
|
| 1744 |
-
with gr.Tab('Model Aligned (
|
| 1745 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 1746 |
gr.Markdown('---')
|
| 1747 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
@@ -1965,8 +2074,9 @@ with demo:
|
|
| 1965 |
return gr.update()
|
| 1966 |
|
| 1967 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
| 1968 |
-
hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
|
| 1969 |
-
hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
|
|
|
|
| 1970 |
|
| 1971 |
with gr.Row():
|
| 1972 |
with gr.Column():
|
|
|
|
| 325 |
mask = mask[mask_sort_idx[:3]]
|
| 326 |
sort_values.append(mask.mean().item())
|
| 327 |
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
| 328 |
+
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
|
| 329 |
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
| 330 |
+
top10_image_idx[idx.item()] = mask_sort_idx[:6]
|
| 331 |
# do the sorting
|
| 332 |
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
| 333 |
fps_idx = fps_idx[_sort_idx]
|
|
|
|
| 427 |
alignedcut_eig_norm_plot=False,
|
| 428 |
**kwargs,
|
| 429 |
):
|
| 430 |
+
advanced = kwargs.get("advanced", False)
|
| 431 |
progress = gr.Progress()
|
| 432 |
progress(0.2, desc="Feature Extraction")
|
| 433 |
|
|
|
|
| 484 |
|
| 485 |
if recursion:
|
| 486 |
rgbs = []
|
| 487 |
+
all_eigvecs = []
|
| 488 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
| 489 |
inp = features
|
| 490 |
progress_start = 0.4
|
|
|
|
| 511 |
progess_start=progress_start,
|
| 512 |
)
|
| 513 |
logging_str += _logging_str
|
| 514 |
+
all_eigvecs.append(eigvecs.cpu().clone())
|
| 515 |
|
| 516 |
|
| 517 |
if "AlignedThreeModelAttnNodes" == model_name:
|
|
|
|
| 531 |
inp = eigvecs.reshape(*features.shape[:-1], -1)
|
| 532 |
if recursion_metric == "cosine":
|
| 533 |
inp = F.normalize(inp, dim=-1)
|
| 534 |
+
|
| 535 |
+
if not advanced:
|
| 536 |
+
return rgbs[0], rgbs[1], rgbs[2], logging_str
|
| 537 |
+
if advanced:
|
| 538 |
+
cluster_plots, norm_plots = [], []
|
| 539 |
+
for i in range(3):
|
| 540 |
+
eigvecs = all_eigvecs[i]
|
| 541 |
+
# add norm plot, cluster plot
|
| 542 |
+
start = time.time()
|
| 543 |
+
progress_start = 0.6
|
| 544 |
+
progress(progress_start, desc=f"Plotting Clusters Recursion #{i+1}")
|
| 545 |
+
h, w = features.shape[1], features.shape[2]
|
| 546 |
+
if torch.cuda.is_available():
|
| 547 |
+
images = images.cuda()
|
| 548 |
+
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 549 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
|
| 550 |
+
logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n"
|
| 551 |
+
|
| 552 |
+
norm_images = []
|
| 553 |
+
vmin, vmax = eig_magnitude.min(), eig_magnitude.max()
|
| 554 |
+
eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin)
|
| 555 |
+
eig_magnitude = eig_magnitude.cpu().numpy()
|
| 556 |
+
colormap = matplotlib.colormaps['Reds']
|
| 557 |
+
for i_image in range(eig_magnitude.shape[0]):
|
| 558 |
+
norm_image = colormap(eig_magnitude[i_image])
|
| 559 |
+
norm_images.append(torch.tensor(norm_image[..., :3]))
|
| 560 |
+
norm_images = to_pil_images(norm_images)
|
| 561 |
+
logging_str += f"Recursion #{i+1} Eigenvector Magnitude: [{vmin:.2f}, {vmax:.2f}]\n"
|
| 562 |
+
gr.Info(f"Recursion #{i+1} Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
| 563 |
+
|
| 564 |
+
cluster_plots.append(cluster_images)
|
| 565 |
+
norm_plots.append(norm_images)
|
| 566 |
+
|
| 567 |
+
return *rgbs, *norm_plots, *cluster_plots, logging_str
|
| 568 |
+
|
| 569 |
+
|
| 570 |
if old_school_ncut: # individual images
|
| 571 |
logging_str += "Running NCut for each image independently\n"
|
| 572 |
rgb = []
|
|
|
|
| 680 |
norm_images = to_pil_images(norm_images)
|
| 681 |
logging_str += "Eigenvector Magnitude\n"
|
| 682 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 683 |
+
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10)
|
| 684 |
|
| 685 |
return to_pil_images(rgb), cluster_images, norm_images, logging_str
|
| 686 |
|
|
|
|
| 688 |
|
| 689 |
def _ncut_run(*args, **kwargs):
|
| 690 |
n_ret = kwargs.pop("n_ret", 1)
|
| 691 |
+
try:
|
| 692 |
+
if torch.cuda.is_available():
|
| 693 |
+
torch.cuda.empty_cache()
|
| 694 |
|
| 695 |
+
ret = ncut_run(*args, **kwargs)
|
| 696 |
|
| 697 |
+
if torch.cuda.is_available():
|
| 698 |
+
torch.cuda.empty_cache()
|
| 699 |
|
| 700 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 701 |
+
return ret
|
| 702 |
+
except Exception as e:
|
| 703 |
+
gr.Error(str(e))
|
| 704 |
+
if torch.cuda.is_available():
|
| 705 |
+
torch.cuda.empty_cache()
|
| 706 |
+
return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 707 |
+
|
| 708 |
+
# ret = ncut_run(*args, **kwargs)
|
| 709 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
| 710 |
+
# return ret
|
| 711 |
|
| 712 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 713 |
@spaces.GPU(duration=30)
|
|
|
|
| 1444 |
scroll_to_output=True,
|
| 1445 |
)
|
| 1446 |
|
| 1447 |
+
with gr.Tab('AlignedCut (Advanced)', visible=False) as tab_alignedcut_advanced:
|
| 1448 |
|
| 1449 |
with gr.Row():
|
| 1450 |
with gr.Column(scale=5, min_width=200):
|
|
|
|
| 1619 |
outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
|
| 1620 |
api_name="API_RecursiveCut"
|
| 1621 |
)
|
| 1622 |
+
|
| 1623 |
+
with gr.Tab('Recursive Cut (Advanced)', visible=False) as tab_recursivecut_advanced:
|
| 1624 |
+
|
| 1625 |
+
with gr.Row():
|
| 1626 |
+
with gr.Column(scale=5, min_width=200):
|
| 1627 |
+
gr.Markdown('### Output (Recursion #1)')
|
| 1628 |
+
l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1629 |
+
add_output_images_buttons(l1_gallery)
|
| 1630 |
+
l1_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1631 |
+
l1_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
| 1632 |
+
with gr.Column(scale=5, min_width=200):
|
| 1633 |
+
gr.Markdown('### Output (Recursion #2)')
|
| 1634 |
+
l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1635 |
+
add_output_images_buttons(l2_gallery)
|
| 1636 |
+
l2_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1637 |
+
l2_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
| 1638 |
+
with gr.Column(scale=5, min_width=200):
|
| 1639 |
+
gr.Markdown('### Output (Recursion #3)')
|
| 1640 |
+
l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
|
| 1641 |
+
add_output_images_buttons(l3_gallery)
|
| 1642 |
+
l3_norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 1643 |
+
l3_cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height=600, show_share_button=True, preview=True, interactive=False)
|
| 1644 |
+
with gr.Row():
|
| 1645 |
+
with gr.Column(scale=5, min_width=200):
|
| 1646 |
+
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
| 1647 |
+
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section(advanced=True)
|
| 1648 |
+
num_images_slider.value = 100
|
| 1649 |
+
clear_images_button.visible = False
|
| 1650 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
|
| 1651 |
+
|
| 1652 |
+
with gr.Column(scale=5, min_width=200):
|
| 1653 |
+
with gr.Accordion("➡️ Recursion config", open=True):
|
| 1654 |
+
l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
|
| 1655 |
+
l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
|
| 1656 |
+
l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
|
| 1657 |
+
metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
|
| 1658 |
+
l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma")
|
| 1659 |
+
l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma")
|
| 1660 |
+
l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
|
| 1661 |
+
[
|
| 1662 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 1663 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 1664 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1665 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 1666 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 1667 |
+
] = make_parameters_section()
|
| 1668 |
+
num_eig_slider.visible = False
|
| 1669 |
+
affinity_focal_gamma_slider.visible = False
|
| 1670 |
+
true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
|
| 1671 |
+
true_placeholder.visible = False
|
| 1672 |
+
false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
|
| 1673 |
+
false_placeholder.visible = False
|
| 1674 |
+
number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
|
| 1675 |
+
number_placeholder.visible = False
|
| 1676 |
+
clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
|
| 1677 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1678 |
+
|
| 1679 |
+
submit_button.click(
|
| 1680 |
+
partial(run_fn, n_ret=9, advanced=True),
|
| 1681 |
+
inputs=[
|
| 1682 |
+
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
| 1683 |
+
positive_prompt, negative_prompt,
|
| 1684 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 1685 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 1686 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 1687 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
| 1688 |
+
false_placeholder, number_placeholder, true_placeholder,
|
| 1689 |
+
l2_num_eig_slider, l3_num_eig_slider, metric_dropdown,
|
| 1690 |
+
l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider
|
| 1691 |
+
],
|
| 1692 |
+
outputs=[l1_gallery, l2_gallery, l3_gallery, l1_norm_gallery, l2_norm_gallery, l3_norm_gallery, l1_cluster_gallery, l2_cluster_gallery, l3_cluster_gallery, logging_text],
|
| 1693 |
+
)
|
| 1694 |
|
| 1695 |
|
| 1696 |
with gr.Tab('Video'):
|
|
|
|
| 1850 |
outputs=[output_gallery, logging_text],
|
| 1851 |
)
|
| 1852 |
|
| 1853 |
+
with gr.Tab('Model Aligned (Advanced)', visible=False) as tab_model_aligned_advanced:
|
| 1854 |
gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)')
|
| 1855 |
gr.Markdown('---')
|
| 1856 |
gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.')
|
|
|
|
| 2074 |
return gr.update()
|
| 2075 |
|
| 2076 |
hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button])
|
| 2077 |
+
hidden_button.change(unlock_tabs_with_info, n_smiles, tab_alignedcut_advanced)
|
| 2078 |
+
hidden_button.change(unlock_tabs, n_smiles, tab_model_aligned_advanced)
|
| 2079 |
+
hidden_button.change(unlock_tabs, n_smiles, tab_recursivecut_advanced)
|
| 2080 |
|
| 2081 |
with gr.Row():
|
| 2082 |
with gr.Column():
|