Spaces:
Running
on
Zero
Running
on
Zero
add fg
Browse files
app.py
CHANGED
|
@@ -409,34 +409,40 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
| 409 |
return blended.astype(np.uint8)
|
| 410 |
|
| 411 |
|
| 412 |
-
def segment_fg_bg(images):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
-
images = F.interpolate(images, (
|
| 415 |
|
| 416 |
# model = load_alignedthreemodel()
|
| 417 |
model = load_model("CLIP(ViT-B-16/openai)")
|
| 418 |
from ncut_pytorch.backbone import resample_position_embeddings
|
| 419 |
pos_embed = model.model.visual.positional_embedding
|
| 420 |
-
pos_embed = resample_position_embeddings(pos_embed,
|
| 421 |
model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)
|
| 422 |
|
| 423 |
-
batch_size = 4
|
| 424 |
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 425 |
|
| 426 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 427 |
model.to(device)
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
| 430 |
|
| 431 |
fg_acts, bg_acts = [], []
|
| 432 |
for chunk_idx in chunk_idxs:
|
| 433 |
with torch.no_grad():
|
| 434 |
input_images = images[chunk_idx].to(device)
|
| 435 |
# transform the input images
|
| 436 |
-
|
|
|
|
| 437 |
# output = model(input_images)[:, 5]
|
| 438 |
-
output = model(input_images)['attn'][
|
| 439 |
-
fg_act = output[:,
|
| 440 |
bg_act = output[:, 0, 0].mean(0)
|
| 441 |
fg_acts.append(fg_act)
|
| 442 |
bg_acts.append(bg_act)
|
|
@@ -445,21 +451,6 @@ def segment_fg_bg(images):
|
|
| 445 |
fg_act = F.normalize(fg_act, dim=-1)
|
| 446 |
bg_act = F.normalize(bg_act, dim=-1)
|
| 447 |
|
| 448 |
-
# ref_image = default_images[0]
|
| 449 |
-
# image = Image.open(ref_image).convert("RGB").resize((224, 224), Image.Resampling.BILINEAR)
|
| 450 |
-
# image = torch.tensor(np.array(image)).permute(2, 0, 1).float().to(device)
|
| 451 |
-
# image = (image / 255.0 - means) / stds
|
| 452 |
-
# output = model(image)['attn'][6][0]
|
| 453 |
-
# # print(output.shape)
|
| 454 |
-
# # bg on the center
|
| 455 |
-
# fg_act = output[5, 5]
|
| 456 |
-
# # bg on the bottom left
|
| 457 |
-
# bg_act = output[0, 0]
|
| 458 |
-
# fg_act = F.normalize(fg_act, dim=-1)
|
| 459 |
-
# bg_act = F.normalize(bg_act, dim=-1)
|
| 460 |
-
|
| 461 |
-
# print(images.mean(), images.std())
|
| 462 |
-
|
| 463 |
fg_act, bg_act = fg_act.to(device), bg_act.to(device)
|
| 464 |
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 465 |
heatmap_fgs, heatmap_bgs = [], []
|
|
@@ -467,9 +458,10 @@ def segment_fg_bg(images):
|
|
| 467 |
with torch.no_grad():
|
| 468 |
input_images = images[chunk_idx].to(device)
|
| 469 |
# transform the input images
|
| 470 |
-
|
|
|
|
| 471 |
# output = model(input_images)[:, 5]
|
| 472 |
-
output = model(input_images)['attn'][
|
| 473 |
output = F.normalize(output, dim=-1)
|
| 474 |
heatmap_fg = output @ fg_act[:, None] # [B, H, W, 1]
|
| 475 |
heatmap_bg = output @ bg_act[:, None] # [B, H, W, 1]
|
|
@@ -868,6 +860,71 @@ def ncut_run(
|
|
| 868 |
return to_pil_images(rgb), logging_str
|
| 869 |
|
| 870 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
|
| 872 |
# alignedcut
|
| 873 |
if not directed:
|
|
@@ -1037,9 +1094,9 @@ def _ncut_run(*args, **kwargs):
|
|
| 1037 |
torch.cuda.empty_cache()
|
| 1038 |
return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
|
| 1044 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 1045 |
@spaces.GPU(duration=30)
|
|
@@ -1250,6 +1307,7 @@ def run_fn(
|
|
| 1250 |
node_type2="k",
|
| 1251 |
head_index_text='all',
|
| 1252 |
make_symmetric=False,
|
|
|
|
| 1253 |
n_ret=1,
|
| 1254 |
plot_clusters=False,
|
| 1255 |
alignedcut_eig_norm_plot=False,
|
|
@@ -1258,6 +1316,7 @@ def run_fn(
|
|
| 1258 |
only_eigvecs=False,
|
| 1259 |
return_eigvec_and_rgb=False,
|
| 1260 |
normalize_eigvec_return=False,
|
|
|
|
| 1261 |
):
|
| 1262 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1263 |
progress=gr.Progress()
|
|
@@ -1390,6 +1449,7 @@ def run_fn(
|
|
| 1390 |
"lisa_prompt2": lisa_prompt2,
|
| 1391 |
"lisa_prompt3": lisa_prompt3,
|
| 1392 |
"is_lisa": is_lisa,
|
|
|
|
| 1393 |
"n_ret": n_ret,
|
| 1394 |
"plot_clusters": plot_clusters,
|
| 1395 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
|
@@ -1401,6 +1461,7 @@ def run_fn(
|
|
| 1401 |
"only_eigvecs": only_eigvecs,
|
| 1402 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
| 1403 |
"normalize_eigvec_return": normalize_eigvec_return,
|
|
|
|
| 1404 |
}
|
| 1405 |
# print(kwargs)
|
| 1406 |
|
|
@@ -2196,7 +2257,8 @@ demo = gr.Blocks(
|
|
| 2196 |
css=custom_css,
|
| 2197 |
)
|
| 2198 |
with demo:
|
| 2199 |
-
|
|
|
|
| 2200 |
with gr.Tab('PlayGround'):
|
| 2201 |
eigvecs = gr.State(np.array([]))
|
| 2202 |
tsne3d_rgb = gr.State(np.array([]))
|
|
@@ -4247,6 +4309,45 @@ with demo:
|
|
| 4247 |
outputs=[mask_gallery, crop_gallery])
|
| 4248 |
|
| 4249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4250 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
| 4251 |
with gr.Row():
|
| 4252 |
image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
|
|
|
|
| 409 |
return blended.astype(np.uint8)
|
| 410 |
|
| 411 |
|
| 412 |
+
def segment_fg_bg(images, hw=224, i_layer=6, batch_size=4, transform_images=True):
|
| 413 |
+
|
| 414 |
+
assert hw % 16 == 0, "The height and width of the image must be divisible by 16."
|
| 415 |
+
psz = hw // 16
|
| 416 |
+
center_xy = (psz-1) // 2
|
| 417 |
|
| 418 |
+
images = F.interpolate(images, (hw, hw), mode="bilinear")
|
| 419 |
|
| 420 |
# model = load_alignedthreemodel()
|
| 421 |
model = load_model("CLIP(ViT-B-16/openai)")
|
| 422 |
from ncut_pytorch.backbone import resample_position_embeddings
|
| 423 |
pos_embed = model.model.visual.positional_embedding
|
| 424 |
+
pos_embed = resample_position_embeddings(pos_embed, psz, psz)
|
| 425 |
model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed)
|
| 426 |
|
|
|
|
| 427 |
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 428 |
|
| 429 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 430 |
model.to(device)
|
| 431 |
+
|
| 432 |
+
if transform_images:
|
| 433 |
+
means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
|
| 434 |
+
stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
|
| 435 |
|
| 436 |
fg_acts, bg_acts = [], []
|
| 437 |
for chunk_idx in chunk_idxs:
|
| 438 |
with torch.no_grad():
|
| 439 |
input_images = images[chunk_idx].to(device)
|
| 440 |
# transform the input images
|
| 441 |
+
if transform_images:
|
| 442 |
+
input_images = (input_images - means) / stds
|
| 443 |
# output = model(input_images)[:, 5]
|
| 444 |
+
output = model(input_images)['attn'][i_layer] # [B, H=14, W=14, C]
|
| 445 |
+
fg_act = output[:, center_xy, center_xy].mean(0)
|
| 446 |
bg_act = output[:, 0, 0].mean(0)
|
| 447 |
fg_acts.append(fg_act)
|
| 448 |
bg_acts.append(bg_act)
|
|
|
|
| 451 |
fg_act = F.normalize(fg_act, dim=-1)
|
| 452 |
bg_act = F.normalize(bg_act, dim=-1)
|
| 453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
fg_act, bg_act = fg_act.to(device), bg_act.to(device)
|
| 455 |
chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
|
| 456 |
heatmap_fgs, heatmap_bgs = [], []
|
|
|
|
| 458 |
with torch.no_grad():
|
| 459 |
input_images = images[chunk_idx].to(device)
|
| 460 |
# transform the input images
|
| 461 |
+
if transform_images:
|
| 462 |
+
input_images = (input_images - means) / stds
|
| 463 |
# output = model(input_images)[:, 5]
|
| 464 |
+
output = model(input_images)['attn'][i_layer]
|
| 465 |
output = F.normalize(output, dim=-1)
|
| 466 |
heatmap_fg = output @ fg_act[:, None] # [B, H, W, 1]
|
| 467 |
heatmap_bg = output @ bg_act[:, None] # [B, H, W, 1]
|
|
|
|
| 860 |
return to_pil_images(rgb), logging_str
|
| 861 |
|
| 862 |
|
| 863 |
+
# fg-bg separated
|
| 864 |
+
separate_fg_bg = kwargs.get("separate_fg_bg", False)
|
| 865 |
+
if separate_fg_bg:
|
| 866 |
+
fg_threshold = kwargs.get("fg_threshold", 0.5)
|
| 867 |
+
feature_hw = features.shape[1]
|
| 868 |
+
progress(0.4, desc="Segmenting FG-BG")
|
| 869 |
+
heatmap_fg, heatmap_bg = segment_fg_bg(images, hw=448, transform_images=False, i_layer=4)
|
| 870 |
+
heatmap_fg = 1 - heatmap_fg
|
| 871 |
+
heatmap_bg = 1 - heatmap_bg
|
| 872 |
+
b, h, w, c = features.shape
|
| 873 |
+
heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
|
| 874 |
+
heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
|
| 875 |
+
is_cuda = torch.cuda.is_available()
|
| 876 |
+
heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bicubic")
|
| 877 |
+
heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bicubic")
|
| 878 |
+
heatmap_fg = heatmap_fg.flatten()
|
| 879 |
+
heatmap_bg = heatmap_bg.flatten()
|
| 880 |
+
fg_minus_bg = heatmap_fg - heatmap_bg
|
| 881 |
+
|
| 882 |
+
def _to_mask(heatmap, threshold, gamma=0.5):
|
| 883 |
+
heatmap = (heatmap - heatmap.mean()) / heatmap.std()
|
| 884 |
+
heatmap = heatmap.double()
|
| 885 |
+
heatmap = torch.exp(heatmap)
|
| 886 |
+
heatmap = 1 / heatmap ** gamma
|
| 887 |
+
if heatmap.shape[0] > 10000:
|
| 888 |
+
np.random.seed(0)
|
| 889 |
+
random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
|
| 890 |
+
vmin, vmax = heatmap[random_idx].quantile(0.01), heatmap[random_idx].quantile(0.99)
|
| 891 |
+
else:
|
| 892 |
+
vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
|
| 893 |
+
heatmap = (heatmap - vmin) / (vmax - vmin)
|
| 894 |
+
heatmap = heatmap.reshape(b, h, w)
|
| 895 |
+
mask = heatmap > threshold
|
| 896 |
+
return mask
|
| 897 |
+
|
| 898 |
+
fg_mask = _to_mask(fg_minus_bg, fg_threshold)
|
| 899 |
+
features_fg = features.flatten(0, 2)[fg_mask.flatten()]
|
| 900 |
+
|
| 901 |
+
progress(0.4, desc="NCut FG")
|
| 902 |
+
rgb, _logging_str, eigvecs = compute_ncut(
|
| 903 |
+
features_fg,
|
| 904 |
+
num_eig=num_eig,
|
| 905 |
+
num_sample_ncut=num_sample_ncut,
|
| 906 |
+
affinity_focal_gamma=affinity_focal_gamma,
|
| 907 |
+
knn_ncut=knn_ncut,
|
| 908 |
+
knn_tsne=knn_tsne,
|
| 909 |
+
num_sample_tsne=num_sample_tsne,
|
| 910 |
+
embedding_method=embedding_method,
|
| 911 |
+
embedding_metric=embedding_metric,
|
| 912 |
+
perplexity=perplexity,
|
| 913 |
+
n_neighbors=n_neighbors,
|
| 914 |
+
min_dist=min_dist,
|
| 915 |
+
sampling_method=sampling_method,
|
| 916 |
+
indirect_connection=indirect_connection,
|
| 917 |
+
make_orthogonal=make_orthogonal,
|
| 918 |
+
metric=ncut_metric,
|
| 919 |
+
only_eigvecs=False,
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
rgb_all = torch.zeros(b, h, w, 3)
|
| 923 |
+
rgb_all_flat = rgb_all.flatten(0, 2)
|
| 924 |
+
rgb_all_flat[fg_mask.flatten()] = rgb
|
| 925 |
+
rgb_all = rgb_all_flat.reshape(b, h, w, 3)
|
| 926 |
+
|
| 927 |
+
return to_pil_images(rgb_all), logging_str
|
| 928 |
|
| 929 |
# alignedcut
|
| 930 |
if not directed:
|
|
|
|
| 1094 |
torch.cuda.empty_cache()
|
| 1095 |
return *(None for _ in range(n_ret)), "Error: " + str(e)
|
| 1096 |
|
| 1097 |
+
ret = ncut_run(*args, **kwargs)
|
| 1098 |
+
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 1099 |
+
return ret
|
| 1100 |
|
| 1101 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 1102 |
@spaces.GPU(duration=30)
|
|
|
|
| 1307 |
node_type2="k",
|
| 1308 |
head_index_text='all',
|
| 1309 |
make_symmetric=False,
|
| 1310 |
+
fg_threshold=0.5,
|
| 1311 |
n_ret=1,
|
| 1312 |
plot_clusters=False,
|
| 1313 |
alignedcut_eig_norm_plot=False,
|
|
|
|
| 1316 |
only_eigvecs=False,
|
| 1317 |
return_eigvec_and_rgb=False,
|
| 1318 |
normalize_eigvec_return=False,
|
| 1319 |
+
separate_fg_bg=False,
|
| 1320 |
):
|
| 1321 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1322 |
progress=gr.Progress()
|
|
|
|
| 1449 |
"lisa_prompt2": lisa_prompt2,
|
| 1450 |
"lisa_prompt3": lisa_prompt3,
|
| 1451 |
"is_lisa": is_lisa,
|
| 1452 |
+
"fg_threshold": fg_threshold,
|
| 1453 |
"n_ret": n_ret,
|
| 1454 |
"plot_clusters": plot_clusters,
|
| 1455 |
"alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
|
|
|
|
| 1461 |
"only_eigvecs": only_eigvecs,
|
| 1462 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
| 1463 |
"normalize_eigvec_return": normalize_eigvec_return,
|
| 1464 |
+
"separate_fg_bg": separate_fg_bg,
|
| 1465 |
}
|
| 1466 |
# print(kwargs)
|
| 1467 |
|
|
|
|
| 2257 |
css=custom_css,
|
| 2258 |
)
|
| 2259 |
with demo:
|
| 2260 |
+
|
| 2261 |
+
|
| 2262 |
with gr.Tab('PlayGround'):
|
| 2263 |
eigvecs = gr.State(np.array([]))
|
| 2264 |
tsne3d_rgb = gr.State(np.array([]))
|
|
|
|
| 4309 |
outputs=[mask_gallery, crop_gallery])
|
| 4310 |
|
| 4311 |
|
| 4312 |
+
with gr.Tab('FG'):
|
| 4313 |
+
|
| 4314 |
+
with gr.Row():
|
| 4315 |
+
with gr.Column(scale=5, min_width=200):
|
| 4316 |
+
input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
|
| 4317 |
+
num_images_slider.value = 30
|
| 4318 |
+
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 4319 |
+
|
| 4320 |
+
with gr.Column(scale=5, min_width=200):
|
| 4321 |
+
output_gallery = make_output_images_section()
|
| 4322 |
+
fg_threshold_slider = gr.Slider(0.01, 1, step=0.01, label="Foreground threshold", value=0.5, elem_id="fg_threshold", info="increase for more foreground")
|
| 4323 |
+
# cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
|
| 4324 |
+
[
|
| 4325 |
+
model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
|
| 4326 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4327 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4328 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
| 4329 |
+
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
| 4330 |
+
] = make_parameters_section()
|
| 4331 |
+
|
| 4332 |
+
false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
|
| 4333 |
+
no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
|
| 4334 |
+
|
| 4335 |
+
submit_button.click(
|
| 4336 |
+
partial(run_fn, n_ret=1, plot_clusters=False, separate_fg_bg=True),
|
| 4337 |
+
inputs=[
|
| 4338 |
+
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 4339 |
+
positive_prompt, negative_prompt,
|
| 4340 |
+
false_placeholder, no_prompt, no_prompt, no_prompt,
|
| 4341 |
+
affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
|
| 4342 |
+
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
| 4343 |
+
perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
|
| 4344 |
+
*[false_placeholder]*12,
|
| 4345 |
+
fg_threshold_slider
|
| 4346 |
+
],
|
| 4347 |
+
outputs=[output_gallery, logging_text],
|
| 4348 |
+
)
|
| 4349 |
+
|
| 4350 |
+
|
| 4351 |
with gr.Tab('Sub-cluster (dev)', visible=False) as sub_cluster_tab:
|
| 4352 |
with gr.Row():
|
| 4353 |
image_cluster_plot = gr.Image(value=None, label="Image-level clustering", elem_id="image_cluster_plot", interactive=False)
|