Spaces:
Running
on
Zero
Running
on
Zero
update advanced plots
Browse files
app.py
CHANGED
|
@@ -257,15 +257,19 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
| 257 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 258 |
return blended.astype(np.uint8)
|
| 259 |
|
| 260 |
-
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
| 261 |
progress = gr.Progress()
|
| 262 |
progress(progess_start, desc="Finding Clusters by FPS")
|
| 263 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 264 |
eigvecs = eigvecs.to(device)
|
| 265 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 266 |
magnitude = torch.norm(eigvecs, dim=-1)
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
| 271 |
|
|
@@ -283,7 +287,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
|
| 283 |
right = F.normalize(right, dim=-1)
|
| 284 |
heatmap = left @ right.T
|
| 285 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 286 |
-
num_samples = 80
|
| 287 |
if num_samples > fps_idx.shape[0]:
|
| 288 |
num_samples = fps_idx.shape[0]
|
| 289 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
@@ -328,13 +332,19 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
|
| 328 |
# reverse the fps_idx
|
| 329 |
# fps_idx = fps_idx.flip(0)
|
| 330 |
# discard the big clusters
|
| 331 |
-
|
| 332 |
-
#
|
| 333 |
-
fps_idx = fps_idx[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
fig_images = []
|
| 336 |
i_cluster = 0
|
| 337 |
-
num_plots = 10
|
| 338 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 339 |
for i_fig in range(num_plots):
|
| 340 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
|
@@ -609,7 +619,8 @@ def ncut_run(
|
|
| 609 |
if torch.cuda.is_available():
|
| 610 |
images = images.cuda()
|
| 611 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 612 |
-
|
|
|
|
| 613 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 614 |
|
| 615 |
norm_images = None
|
|
@@ -622,8 +633,10 @@ def ncut_run(
|
|
| 622 |
colormap = matplotlib.colormaps['Reds']
|
| 623 |
for i_image in range(eig_magnitude.shape[0]):
|
| 624 |
norm_image = colormap(eig_magnitude[i_image])
|
| 625 |
-
norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
|
| 626 |
-
norm_images.append(Image.fromarray(norm_image))
|
|
|
|
|
|
|
| 627 |
logging_str += "Eigenvector Magnitude\n"
|
| 628 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 629 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
|
|
@@ -855,6 +868,7 @@ def run_fn(
|
|
| 855 |
n_ret=1,
|
| 856 |
plot_clusters=False,
|
| 857 |
alignedcut_eig_norm_plot=False,
|
|
|
|
| 858 |
):
|
| 859 |
|
| 860 |
progress=gr.Progress()
|
|
@@ -987,6 +1001,7 @@ def run_fn(
|
|
| 987 |
"n_ret": n_ret,
|
| 988 |
"plot_clusters": plot_clusters,
|
| 989 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
|
|
|
| 990 |
}
|
| 991 |
# print(kwargs)
|
| 992 |
|
|
@@ -1416,7 +1431,7 @@ with demo:
|
|
| 1416 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1417 |
|
| 1418 |
submit_button.click(
|
| 1419 |
-
partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True),
|
| 1420 |
inputs=[
|
| 1421 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 1422 |
positive_prompt, negative_prompt,
|
|
@@ -1802,7 +1817,7 @@ with demo:
|
|
| 1802 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1803 |
|
| 1804 |
submit_button.click(
|
| 1805 |
-
partial(run_fn, n_ret=3),
|
| 1806 |
inputs=[
|
| 1807 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
| 1808 |
positive_prompt, negative_prompt,
|
|
|
|
| 257 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 258 |
return blended.astype(np.uint8)
|
| 259 |
|
| 260 |
+
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False):
|
| 261 |
progress = gr.Progress()
|
| 262 |
progress(progess_start, desc="Finding Clusters by FPS")
|
| 263 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 264 |
eigvecs = eigvecs.to(device)
|
| 265 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 266 |
magnitude = torch.norm(eigvecs, dim=-1)
|
| 267 |
+
|
| 268 |
+
# gr.Info("Finding Clusters by FPS, no magnitude filtering")
|
| 269 |
+
top_p_idx = torch.arange(eigvecs.shape[0])
|
| 270 |
+
# gr.Info("Finding Clusters by FPS, with magnitude filtering")
|
| 271 |
+
# p = 0.8
|
| 272 |
+
# top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
| 273 |
|
| 274 |
ret_magnitude = magnitude.reshape(-1, h, w)
|
| 275 |
|
|
|
|
| 287 |
right = F.normalize(right, dim=-1)
|
| 288 |
heatmap = left @ right.T
|
| 289 |
heatmap = F.normalize(heatmap, dim=-1)
|
| 290 |
+
num_samples = 80 if not advanced else 130
|
| 291 |
if num_samples > fps_idx.shape[0]:
|
| 292 |
num_samples = fps_idx.shape[0]
|
| 293 |
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
|
|
|
| 332 |
# reverse the fps_idx
|
| 333 |
# fps_idx = fps_idx.flip(0)
|
| 334 |
# discard the big clusters
|
| 335 |
+
|
| 336 |
+
# gr.Info("Discarding the biggest 10 clusters")
|
| 337 |
+
# fps_idx = fps_idx[10:]
|
| 338 |
+
# gr.Info("Not discarding the biggest 10 clusters")
|
| 339 |
+
# gr.Info("Discarding the smallest 30 out of 80 sampled clusters")
|
| 340 |
+
|
| 341 |
+
if not advanced:
|
| 342 |
+
# shuffle the fps_idx
|
| 343 |
+
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
| 344 |
|
| 345 |
fig_images = []
|
| 346 |
i_cluster = 0
|
| 347 |
+
num_plots = 10 if not advanced else 20
|
| 348 |
plot_step_float = (1.0 - progess_start) / num_plots
|
| 349 |
for i_fig in range(num_plots):
|
| 350 |
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
|
|
|
| 619 |
if torch.cuda.is_available():
|
| 620 |
images = images.cuda()
|
| 621 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 622 |
+
advanced = kwargs.get("advanced", False)
|
| 623 |
+
cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=advanced)
|
| 624 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 625 |
|
| 626 |
norm_images = None
|
|
|
|
| 633 |
colormap = matplotlib.colormaps['Reds']
|
| 634 |
for i_image in range(eig_magnitude.shape[0]):
|
| 635 |
norm_image = colormap(eig_magnitude[i_image])
|
| 636 |
+
# norm_image = (norm_image[..., :3] * 255).astype(np.uint8)
|
| 637 |
+
# norm_images.append(Image.fromarray(norm_image))
|
| 638 |
+
norm_images.append(torch.tensor(norm_image[..., :3]))
|
| 639 |
+
norm_images = to_pil_images(norm_images)
|
| 640 |
logging_str += "Eigenvector Magnitude\n"
|
| 641 |
logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n"
|
| 642 |
gr.Info(f"Eigenvector Magnitude:</br> Min: {vmin:.2f}, Max: {vmax:.2f}", duration=0)
|
|
|
|
| 868 |
n_ret=1,
|
| 869 |
plot_clusters=False,
|
| 870 |
alignedcut_eig_norm_plot=False,
|
| 871 |
+
advanced=False,
|
| 872 |
):
|
| 873 |
|
| 874 |
progress=gr.Progress()
|
|
|
|
| 1001 |
"n_ret": n_ret,
|
| 1002 |
"plot_clusters": plot_clusters,
|
| 1003 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
| 1004 |
+
"advanced": advanced,
|
| 1005 |
}
|
| 1006 |
# print(kwargs)
|
| 1007 |
|
|
|
|
| 1431 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1432 |
|
| 1433 |
submit_button.click(
|
| 1434 |
+
partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True),
|
| 1435 |
inputs=[
|
| 1436 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 1437 |
positive_prompt, negative_prompt,
|
|
|
|
| 1817 |
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 1818 |
|
| 1819 |
submit_button.click(
|
| 1820 |
+
partial(run_fn, n_ret=3, advanced=True),
|
| 1821 |
inputs=[
|
| 1822 |
input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown,
|
| 1823 |
positive_prompt, negative_prompt,
|