Spaces:
Running
on
Zero
Running
on
Zero
add progress bar
Browse files- app.py +108 -30
- app_text.py +9 -1
app.py
CHANGED
|
@@ -77,7 +77,9 @@ def compute_ncut(
|
|
| 77 |
min_dist=0.1,
|
| 78 |
sampling_method="fps",
|
| 79 |
metric="cosine",
|
|
|
|
| 80 |
):
|
|
|
|
| 81 |
logging_str = ""
|
| 82 |
|
| 83 |
num_nodes = np.prod(features.shape[:-1])
|
|
@@ -88,6 +90,7 @@ def compute_ncut(
|
|
| 88 |
logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
|
| 89 |
|
| 90 |
start = time.time()
|
|
|
|
| 91 |
eigvecs, eigvals = NCUT(
|
| 92 |
num_eig=num_eig,
|
| 93 |
num_sample=num_sample_ncut,
|
|
@@ -102,6 +105,7 @@ def compute_ncut(
|
|
| 102 |
logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
|
| 103 |
|
| 104 |
start = time.time()
|
|
|
|
| 105 |
_, rgb = eigenvector_to_rgb(
|
| 106 |
eigvecs,
|
| 107 |
method=embedding_method,
|
|
@@ -249,15 +253,34 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
|
|
| 249 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 250 |
return blended.astype(np.uint8)
|
| 251 |
|
| 252 |
-
def make_cluster_plot(eigvecs, images, h=64, w=64):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 254 |
magnitude = torch.norm(eigvecs, dim=-1)
|
| 255 |
-
p = 0.
|
| 256 |
top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
| 257 |
-
num_samples =
|
|
|
|
|
|
|
| 258 |
fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples)
|
| 259 |
fps_idx = top_p_idx[fps_idx]
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
# downsample to 256x256
|
| 262 |
images = F.interpolate(images, (256, 256), mode="bilinear")
|
| 263 |
images = images.cpu().numpy()
|
|
@@ -269,29 +292,57 @@ def make_cluster_plot(eigvecs, images, h=64, w=64):
|
|
| 269 |
# sort the fps_idx by the mean of the heatmap
|
| 270 |
fps_heatmaps = {}
|
| 271 |
sort_values = []
|
|
|
|
| 272 |
for _, idx in enumerate(fps_idx):
|
| 273 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 274 |
-
eigvecs = eigvecs.to(device)
|
| 275 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
heatmap = heatmap.reshape(-1, h, w)
|
| 277 |
-
mask = (heatmap >
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
sort_values.append(mask.mean().item())
|
| 279 |
-
fps_heatmaps[idx.item()] = heatmap.cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
fig_images = []
|
| 282 |
i_cluster = 0
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
| 284 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
| 285 |
for ax in axs.flatten():
|
| 286 |
ax.axis("off")
|
| 287 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
| 288 |
heatmap = fps_heatmaps[idx.item()]
|
| 289 |
-
mask = (heatmap > 0.1).float()
|
| 290 |
-
sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
|
| 291 |
size = (images.shape[1], images.shape[2])
|
| 292 |
heatmap = apply_reds_colormap(heatmap, size)
|
| 293 |
-
for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
| 294 |
-
|
|
|
|
|
|
|
| 295 |
axs[i, j].imshow(_heatmap)
|
| 296 |
if i == 0:
|
| 297 |
axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
|
|
@@ -348,6 +399,9 @@ def ncut_run(
|
|
| 348 |
lisa_prompt2="",
|
| 349 |
lisa_prompt3="",
|
| 350 |
):
|
|
|
|
|
|
|
|
|
|
| 351 |
logging_str = ""
|
| 352 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 353 |
# dirty patch for the alignedcut paper
|
|
@@ -396,12 +450,16 @@ def ncut_run(
|
|
| 396 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
| 397 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
| 398 |
|
|
|
|
|
|
|
| 399 |
if recursion:
|
| 400 |
rgbs = []
|
| 401 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
| 402 |
inp = features
|
|
|
|
| 403 |
for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
|
| 404 |
logging_str += f"Recursion #{i+1}\n"
|
|
|
|
| 405 |
rgb, _logging_str, eigvecs = compute_ncut(
|
| 406 |
inp,
|
| 407 |
num_eig=n_eigs,
|
|
@@ -417,6 +475,7 @@ def ncut_run(
|
|
| 417 |
min_dist=min_dist,
|
| 418 |
sampling_method=sampling_method,
|
| 419 |
metric="cosine" if i == 0 else recursion_metric,
|
|
|
|
| 420 |
)
|
| 421 |
logging_str += _logging_str
|
| 422 |
|
|
@@ -424,6 +483,7 @@ def ncut_run(
|
|
| 424 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 425 |
# dirty patch for the alignedcut paper
|
| 426 |
start = time.time()
|
|
|
|
| 427 |
pil_images = []
|
| 428 |
for i_image in range(rgb.shape[0]):
|
| 429 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
|
@@ -442,6 +502,8 @@ def ncut_run(
|
|
| 442 |
if old_school_ncut: # individual images
|
| 443 |
logging_str += "Running NCut for each image independently\n"
|
| 444 |
rgb = []
|
|
|
|
|
|
|
| 445 |
for i_image in range(features.shape[0]):
|
| 446 |
logging_str += f"Image #{i_image+1}\n"
|
| 447 |
feature = features[i_image]
|
|
@@ -459,6 +521,7 @@ def ncut_run(
|
|
| 459 |
n_neighbors=n_neighbors,
|
| 460 |
min_dist=min_dist,
|
| 461 |
sampling_method=sampling_method,
|
|
|
|
| 462 |
)
|
| 463 |
logging_str += _logging_str
|
| 464 |
rgb.append(_rgb[0])
|
|
@@ -486,6 +549,7 @@ def ncut_run(
|
|
| 486 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 487 |
# dirty patch for the alignedcut paper
|
| 488 |
start = time.time()
|
|
|
|
| 489 |
pil_images = []
|
| 490 |
for i_image in range(rgb.shape[0]):
|
| 491 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
|
@@ -506,15 +570,18 @@ def ncut_run(
|
|
| 506 |
|
| 507 |
if not video_output:
|
| 508 |
start = time.time()
|
|
|
|
|
|
|
| 509 |
h, w = features.shape[1], features.shape[2]
|
| 510 |
if torch.cuda.is_available():
|
| 511 |
images = images.cuda()
|
| 512 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 513 |
-
cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
|
| 514 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 515 |
|
| 516 |
|
| 517 |
if video_output:
|
|
|
|
| 518 |
video_path = get_random_path()
|
| 519 |
video_cache.add_video(video_path)
|
| 520 |
pil_images_to_video(to_pil_images(rgb), video_path)
|
|
@@ -526,26 +593,26 @@ def ncut_run(
|
|
| 526 |
|
| 527 |
def _ncut_run(*args, **kwargs):
|
| 528 |
n_ret = kwargs.pop("n_ret", 1)
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
|
| 533 |
-
|
| 534 |
|
| 535 |
-
|
| 536 |
-
|
| 537 |
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
ret = ncut_run(*args, **kwargs)
|
| 547 |
-
ret = list(ret)[:n_ret] + [ret[-1]]
|
| 548 |
-
return ret
|
| 549 |
|
| 550 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 551 |
@spaces.GPU(duration=20)
|
|
@@ -744,10 +811,15 @@ def run_fn(
|
|
| 744 |
n_ret=1,
|
| 745 |
):
|
| 746 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
if images is None:
|
| 748 |
gr.Warning("No images selected.")
|
| 749 |
return *(None for _ in range(n_ret)), "No images selected."
|
| 750 |
|
|
|
|
| 751 |
video_output = False
|
| 752 |
if isinstance(images, str):
|
| 753 |
images = extract_video_frames(images, max_frames=max_frames)
|
|
@@ -767,6 +839,7 @@ def run_fn(
|
|
| 767 |
images = [transform_image(image, resolution=resolution, stablediffusion=stablediffusion) for image in images]
|
| 768 |
images = torch.stack(images)
|
| 769 |
|
|
|
|
| 770 |
|
| 771 |
if is_lisa:
|
| 772 |
import subprocess
|
|
@@ -976,10 +1049,13 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
| 976 |
def load_dataset_images(is_advanced, dataset_name, num_images=10,
|
| 977 |
is_filter=True, filter_by_class_text="0,1,2",
|
| 978 |
is_random=False, seed=1):
|
|
|
|
|
|
|
| 979 |
if is_advanced == "Basic":
|
| 980 |
gr.Info("Loaded images from Ego-Exo4D")
|
| 981 |
return default_images
|
| 982 |
try:
|
|
|
|
| 983 |
dataset = load_dataset(dataset_name, trust_remote_code=True)
|
| 984 |
key = list(dataset.keys())[0]
|
| 985 |
dataset = dataset[key]
|
|
@@ -990,6 +1066,7 @@ def make_dataset_images_section(advanced=False, is_random=False):
|
|
| 990 |
num_images = len(dataset)
|
| 991 |
|
| 992 |
if is_filter:
|
|
|
|
| 993 |
classes = [int(i) for i in filter_by_class_text.split(",")]
|
| 994 |
labels = np.array(dataset['label'])
|
| 995 |
unique_labels = np.unique(labels)
|
|
@@ -1193,6 +1270,7 @@ with demo:
|
|
| 1193 |
with gr.Column(scale=5, min_width=200):
|
| 1194 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
| 1195 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
|
|
|
|
| 1196 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 1197 |
|
| 1198 |
with gr.Column(scale=5, min_width=200):
|
|
|
|
| 77 |
min_dist=0.1,
|
| 78 |
sampling_method="fps",
|
| 79 |
metric="cosine",
|
| 80 |
+
progess_start=0.4,
|
| 81 |
):
|
| 82 |
+
progress = gr.Progress()
|
| 83 |
logging_str = ""
|
| 84 |
|
| 85 |
num_nodes = np.prod(features.shape[:-1])
|
|
|
|
| 90 |
logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
|
| 91 |
|
| 92 |
start = time.time()
|
| 93 |
+
progress(progess_start+0.0, desc="NCut")
|
| 94 |
eigvecs, eigvals = NCUT(
|
| 95 |
num_eig=num_eig,
|
| 96 |
num_sample=num_sample_ncut,
|
|
|
|
| 105 |
logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
|
| 106 |
|
| 107 |
start = time.time()
|
| 108 |
+
progress(progess_start+0.01, desc="spectral-tSNE")
|
| 109 |
_, rgb = eigenvector_to_rgb(
|
| 110 |
eigvecs,
|
| 111 |
method=embedding_method,
|
|
|
|
| 253 |
blended = (1 - opacity1) * image + opacity2 * heatmap
|
| 254 |
return blended.astype(np.uint8)
|
| 255 |
|
| 256 |
+
def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
|
| 257 |
+
progress = gr.Progress()
|
| 258 |
+
progress(progess_start, desc="Finding Clusters by FPS")
|
| 259 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 260 |
+
eigvecs = eigvecs.to(device)
|
| 261 |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
|
| 262 |
magnitude = torch.norm(eigvecs, dim=-1)
|
| 263 |
+
p = 0.8
|
| 264 |
top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
|
| 265 |
+
num_samples = 300
|
| 266 |
+
if num_samples > top_p_idx.shape[0]:
|
| 267 |
+
num_samples = top_p_idx.shape[0]
|
| 268 |
fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples)
|
| 269 |
fps_idx = top_p_idx[fps_idx]
|
| 270 |
|
| 271 |
+
# fps round 2 on the heatmap
|
| 272 |
+
left = eigvecs[fps_idx, :].clone()
|
| 273 |
+
right = eigvecs.clone()
|
| 274 |
+
left = F.normalize(left, dim=-1)
|
| 275 |
+
right = F.normalize(right, dim=-1)
|
| 276 |
+
heatmap = left @ right.T
|
| 277 |
+
heatmap = F.normalize(heatmap, dim=-1)
|
| 278 |
+
num_samples = 80
|
| 279 |
+
if num_samples > fps_idx.shape[0]:
|
| 280 |
+
num_samples = fps_idx.shape[0]
|
| 281 |
+
r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
|
| 282 |
+
fps_idx = fps_idx[r2_fps_idx]
|
| 283 |
+
|
| 284 |
# downsample to 256x256
|
| 285 |
images = F.interpolate(images, (256, 256), mode="bilinear")
|
| 286 |
images = images.cpu().numpy()
|
|
|
|
| 292 |
# sort the fps_idx by the mean of the heatmap
|
| 293 |
fps_heatmaps = {}
|
| 294 |
sort_values = []
|
| 295 |
+
top3_image_idx = {}
|
| 296 |
for _, idx in enumerate(fps_idx):
|
|
|
|
|
|
|
| 297 |
heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
|
| 298 |
+
|
| 299 |
+
# def top_percentile(tensor, p=0.8, max_size=10000):
|
| 300 |
+
# tensor = tensor.clone().flatten()
|
| 301 |
+
# if tensor.shape[0] > max_size:
|
| 302 |
+
# tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
|
| 303 |
+
# return tensor.quantile(p)
|
| 304 |
+
# top_p = top_percentile(heatmap, p=0.5)
|
| 305 |
+
top_p = 0.5
|
| 306 |
+
|
| 307 |
heatmap = heatmap.reshape(-1, h, w)
|
| 308 |
+
mask = (heatmap > top_p).float()
|
| 309 |
+
# take top 3 masks only
|
| 310 |
+
mask_sort_values = mask.mean((1, 2))
|
| 311 |
+
mask_sort_idx = torch.argsort(mask_sort_values, descending=True)
|
| 312 |
+
mask = mask[mask_sort_idx[:3]]
|
| 313 |
sort_values.append(mask.mean().item())
|
| 314 |
+
# fps_heatmaps[idx.item()] = heatmap.cpu()
|
| 315 |
+
fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:3]].cpu()
|
| 316 |
+
top3_image_idx[idx.item()] = mask_sort_idx[:3]
|
| 317 |
+
# do the sorting
|
| 318 |
+
_sort_idx = torch.tensor(sort_values).argsort(descending=True)
|
| 319 |
+
fps_idx = fps_idx[_sort_idx]
|
| 320 |
+
# reverse the fps_idx
|
| 321 |
+
# fps_idx = fps_idx.flip(0)
|
| 322 |
+
# discard the big clusters
|
| 323 |
+
fps_idx = fps_idx[10:]
|
| 324 |
+
# shuffle the fps_idx
|
| 325 |
+
fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
|
| 326 |
|
| 327 |
fig_images = []
|
| 328 |
i_cluster = 0
|
| 329 |
+
num_plots = 10
|
| 330 |
+
plot_step_float = (1.0 - progess_start) / num_plots
|
| 331 |
+
for i_fig in range(num_plots):
|
| 332 |
+
progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
|
| 333 |
fig, axs = plt.subplots(3, 5, figsize=(15, 9))
|
| 334 |
for ax in axs.flatten():
|
| 335 |
ax.axis("off")
|
| 336 |
for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
|
| 337 |
heatmap = fps_heatmaps[idx.item()]
|
| 338 |
+
# mask = (heatmap > 0.1).float()
|
| 339 |
+
# sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
|
| 340 |
size = (images.shape[1], images.shape[2])
|
| 341 |
heatmap = apply_reds_colormap(heatmap, size)
|
| 342 |
+
# for i, image_idx in enumerate(sorted_image_idxs[:3]):
|
| 343 |
+
for i, image_idx in enumerate(top3_image_idx[idx.item()]):
|
| 344 |
+
# _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
|
| 345 |
+
_heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
|
| 346 |
axs[i, j].imshow(_heatmap)
|
| 347 |
if i == 0:
|
| 348 |
axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
|
|
|
|
| 399 |
lisa_prompt2="",
|
| 400 |
lisa_prompt3="",
|
| 401 |
):
|
| 402 |
+
progress = gr.Progress()
|
| 403 |
+
progress(0.2, desc="Feature Extraction")
|
| 404 |
+
|
| 405 |
logging_str = ""
|
| 406 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 407 |
# dirty patch for the alignedcut paper
|
|
|
|
| 450 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
| 451 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
| 452 |
|
| 453 |
+
progress(0.4, desc="NCut")
|
| 454 |
+
|
| 455 |
if recursion:
|
| 456 |
rgbs = []
|
| 457 |
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
| 458 |
inp = features
|
| 459 |
+
progress_start = 0.4
|
| 460 |
for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
|
| 461 |
logging_str += f"Recursion #{i+1}\n"
|
| 462 |
+
progress_start += + 0.1 * i
|
| 463 |
rgb, _logging_str, eigvecs = compute_ncut(
|
| 464 |
inp,
|
| 465 |
num_eig=n_eigs,
|
|
|
|
| 475 |
min_dist=min_dist,
|
| 476 |
sampling_method=sampling_method,
|
| 477 |
metric="cosine" if i == 0 else recursion_metric,
|
| 478 |
+
progess_start=progress_start,
|
| 479 |
)
|
| 480 |
logging_str += _logging_str
|
| 481 |
|
|
|
|
| 483 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 484 |
# dirty patch for the alignedcut paper
|
| 485 |
start = time.time()
|
| 486 |
+
progress(progress_start + 0.09, desc=f"Plotting Recursion {i+1}")
|
| 487 |
pil_images = []
|
| 488 |
for i_image in range(rgb.shape[0]):
|
| 489 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
|
|
|
| 502 |
if old_school_ncut: # individual images
|
| 503 |
logging_str += "Running NCut for each image independently\n"
|
| 504 |
rgb = []
|
| 505 |
+
progress_start = 0.4
|
| 506 |
+
step_float = 0.6 / features.shape[0]
|
| 507 |
for i_image in range(features.shape[0]):
|
| 508 |
logging_str += f"Image #{i_image+1}\n"
|
| 509 |
feature = features[i_image]
|
|
|
|
| 521 |
n_neighbors=n_neighbors,
|
| 522 |
min_dist=min_dist,
|
| 523 |
sampling_method=sampling_method,
|
| 524 |
+
progess_start=progress_start+step_float*i_image,
|
| 525 |
)
|
| 526 |
logging_str += _logging_str
|
| 527 |
rgb.append(_rgb[0])
|
|
|
|
| 549 |
if "AlignedThreeModelAttnNodes" == model_name:
|
| 550 |
# dirty patch for the alignedcut paper
|
| 551 |
start = time.time()
|
| 552 |
+
progress(0.6, desc="Plotting")
|
| 553 |
pil_images = []
|
| 554 |
for i_image in range(rgb.shape[0]):
|
| 555 |
_im = plot_one_image_36_grid(images[i_image], rgb[i_image])
|
|
|
|
| 570 |
|
| 571 |
if not video_output:
|
| 572 |
start = time.time()
|
| 573 |
+
progress_start = 0.6
|
| 574 |
+
progress(progress_start, desc="Plotting Clusters")
|
| 575 |
h, w = features.shape[1], features.shape[2]
|
| 576 |
if torch.cuda.is_available():
|
| 577 |
images = images.cuda()
|
| 578 |
_images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
|
| 579 |
+
cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
|
| 580 |
logging_str += f"plot time: {time.time() - start:.2f}s\n"
|
| 581 |
|
| 582 |
|
| 583 |
if video_output:
|
| 584 |
+
progress(0.8, desc="Saving Video")
|
| 585 |
video_path = get_random_path()
|
| 586 |
video_cache.add_video(video_path)
|
| 587 |
pil_images_to_video(to_pil_images(rgb), video_path)
|
|
|
|
| 593 |
|
| 594 |
def _ncut_run(*args, **kwargs):
    """Run ``ncut_run`` and trim its result, reporting failures instead of raising.

    ``n_ret`` (keyword, default 1) is popped from ``kwargs``; every other
    argument is forwarded to ``ncut_run`` unchanged.

    Returns:
        On success, a list with the first ``n_ret`` items of the ``ncut_run``
        result followed by its last item.  On failure, a tuple of ``n_ret``
        ``None`` placeholders followed by an ``"Error: ..."`` string in that
        same last slot.
    """
    n_ret = kwargs.pop("n_ret", 1)
    try:
        # Start from a clean allocator so a previous run's cached blocks do
        # not count against this one.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        ret = ncut_run(*args, **kwargs)
        return list(ret)[:n_ret] + [ret[-1]]
    except Exception as e:
        # gr.Error is an exception class: constructing it without ``raise``
        # shows nothing to the user.  Raising would discard the placeholder
        # return below, so surface the message with the call-style
        # gr.Warning notifier instead and still return the error text.
        gr.Warning(str(e))
        return *(None for _ in range(n_ret)), "Error: " + str(e)
    finally:
        # Release cached GPU memory on both the success and the failure path.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
| 612 |
+
|
| 613 |
+
# ret = ncut_run(*args, **kwargs)
|
| 614 |
+
# ret = list(ret)[:n_ret] + [ret[-1]]
|
| 615 |
+
# return ret
|
| 616 |
|
| 617 |
if USE_HUGGINGFACE_ZEROGPU:
|
| 618 |
@spaces.GPU(duration=20)
|
|
|
|
| 811 |
n_ret=1,
|
| 812 |
):
|
| 813 |
|
| 814 |
+
progress=gr.Progress()
|
| 815 |
+
progress(0, desc="Starting")
|
| 816 |
+
|
| 817 |
+
|
| 818 |
if images is None:
|
| 819 |
gr.Warning("No images selected.")
|
| 820 |
return *(None for _ in range(n_ret)), "No images selected."
|
| 821 |
|
| 822 |
+
progress(0.05, desc="Processing Images")
|
| 823 |
video_output = False
|
| 824 |
if isinstance(images, str):
|
| 825 |
images = extract_video_frames(images, max_frames=max_frames)
|
|
|
|
| 839 |
images = [transform_image(image, resolution=resolution, stablediffusion=stablediffusion) for image in images]
|
| 840 |
images = torch.stack(images)
|
| 841 |
|
| 842 |
+
progress(0.1, desc="Downloading Model")
|
| 843 |
|
| 844 |
if is_lisa:
|
| 845 |
import subprocess
|
|
|
|
| 1049 |
def load_dataset_images(is_advanced, dataset_name, num_images=10,
|
| 1050 |
is_filter=True, filter_by_class_text="0,1,2",
|
| 1051 |
is_random=False, seed=1):
|
| 1052 |
+
progress = gr.Progress()
|
| 1053 |
+
progress(0, desc="Loading Images")
|
| 1054 |
if is_advanced == "Basic":
|
| 1055 |
gr.Info("Loaded images from Ego-Exo4D")
|
| 1056 |
return default_images
|
| 1057 |
try:
|
| 1058 |
+
progress(0.5, desc="Downloading Dataset")
|
| 1059 |
dataset = load_dataset(dataset_name, trust_remote_code=True)
|
| 1060 |
key = list(dataset.keys())[0]
|
| 1061 |
dataset = dataset[key]
|
|
|
|
| 1066 |
num_images = len(dataset)
|
| 1067 |
|
| 1068 |
if is_filter:
|
| 1069 |
+
progress(0.8, desc="Filtering Images")
|
| 1070 |
classes = [int(i) for i in filter_by_class_text.split(",")]
|
| 1071 |
labels = np.array(dataset['label'])
|
| 1072 |
unique_labels = np.unique(labels)
|
|
|
|
| 1270 |
with gr.Column(scale=5, min_width=200):
|
| 1271 |
input_gallery, submit_button, clear_images_button = make_input_images_section()
|
| 1272 |
dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
|
| 1273 |
+
num_images_slider.value = 30
|
| 1274 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 1275 |
|
| 1276 |
with gr.Column(scale=5, min_width=200):
|
app_text.py
CHANGED
|
@@ -150,6 +150,7 @@ def ncut_run(
|
|
| 150 |
min_dist=0.1,
|
| 151 |
sampling_method="fps",
|
| 152 |
):
|
|
|
|
| 153 |
logging_str = ""
|
| 154 |
if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
|
| 155 |
# raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
|
|
@@ -163,6 +164,7 @@ def ncut_run(
|
|
| 163 |
|
| 164 |
node_type = node_type.split(":")[0].strip()
|
| 165 |
|
|
|
|
| 166 |
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 167 |
|
| 168 |
start = time.time()
|
|
@@ -180,6 +182,7 @@ def ncut_run(
|
|
| 180 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
| 181 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
| 182 |
|
|
|
|
| 183 |
rgb, _logging_str, _ = compute_ncut(
|
| 184 |
features,
|
| 185 |
num_eig=num_eig,
|
|
@@ -197,6 +200,7 @@ def ncut_run(
|
|
| 197 |
logging_str += _logging_str
|
| 198 |
|
| 199 |
start = time.time()
|
|
|
|
| 200 |
title = f"{model_name}, Layer {layer}, {node_type}"
|
| 201 |
fig = make_plot(token_texts, rgb, title=title)
|
| 202 |
logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
|
|
@@ -223,6 +227,8 @@ else:
|
|
| 223 |
return _ncut_run(*args, **kwargs)
|
| 224 |
|
| 225 |
def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
|
|
|
|
|
|
|
| 226 |
model = TEXT_MODEL_DICT[model_name]()
|
| 227 |
return __ncut_run(model, text, model_name, layer, num_eig, node_type,
|
| 228 |
affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method,
|
|
@@ -251,7 +257,9 @@ def make_demo():
|
|
| 251 |
clear_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
|
| 252 |
with gr.Column(scale=5, min_width=200):
|
| 253 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
layer = gr.Slider(1, 32, step=1, value=32, label="Layer")
|
| 256 |
node_type = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node Type", value="block: sum of residual")
|
| 257 |
num_eig = gr.Slider(minimum=1, maximum=1000, step=1, value=100, label="Number of Eigenvectors")
|
|
|
|
| 150 |
min_dist=0.1,
|
| 151 |
sampling_method="fps",
|
| 152 |
):
|
| 153 |
+
progress = gr.Progress()
|
| 154 |
logging_str = ""
|
| 155 |
if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
|
| 156 |
# raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
|
|
|
|
| 164 |
|
| 165 |
node_type = node_type.split(":")[0].strip()
|
| 166 |
|
| 167 |
+
progress(0.5, desc="Feature Extraction")
|
| 168 |
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 169 |
|
| 170 |
start = time.time()
|
|
|
|
| 182 |
# print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
|
| 183 |
logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
|
| 184 |
|
| 185 |
+
progress(0.6, desc="NCUT & spectral-tSNE")
|
| 186 |
rgb, _logging_str, _ = compute_ncut(
|
| 187 |
features,
|
| 188 |
num_eig=num_eig,
|
|
|
|
| 200 |
logging_str += _logging_str
|
| 201 |
|
| 202 |
start = time.time()
|
| 203 |
+
progress(0.8, desc="Plotting")
|
| 204 |
title = f"{model_name}, Layer {layer}, {node_type}"
|
| 205 |
fig = make_plot(token_texts, rgb, title=title)
|
| 206 |
logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
|
|
|
|
| 227 |
return _ncut_run(*args, **kwargs)
|
| 228 |
|
| 229 |
def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
|
| 230 |
+
progress = gr.Progress()
|
| 231 |
+
progress(0.1, desc="Downloading model")
|
| 232 |
model = TEXT_MODEL_DICT[model_name]()
|
| 233 |
return __ncut_run(model, text, model_name, layer, num_eig, node_type,
|
| 234 |
affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method,
|
|
|
|
| 257 |
clear_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
|
| 258 |
with gr.Column(scale=5, min_width=200):
|
| 259 |
gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
|
| 260 |
+
model_list = list(TEXT_MODEL_DICT.keys())
|
| 261 |
+
model_list = [model for model in model_list if model != "meta-llama/Meta-Llama-3-8B"]
|
| 262 |
+
model_name = gr.Dropdown(model_list, label="Model", value="meta-llama/Meta-Llama-3.1-8B")
|
| 263 |
layer = gr.Slider(1, 32, step=1, value=32, label="Layer")
|
| 264 |
node_type = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node Type", value="block: sum of residual")
|
| 265 |
num_eig = gr.Slider(minimum=1, maximum=1000, step=1, value=100, label="Number of Eigenvectors")
|