Spaces:
Running
on
Zero
Running
on
Zero
add tree+image click
Browse files
app.py
CHANGED
|
@@ -13,6 +13,8 @@ import uuid
|
|
| 13 |
import zipfile
|
| 14 |
import multiprocessing as mp
|
| 15 |
|
|
|
|
|
|
|
| 16 |
from einops import rearrange
|
| 17 |
from matplotlib import pyplot as plt
|
| 18 |
import matplotlib
|
|
@@ -2472,26 +2474,32 @@ with demo:
|
|
| 2472 |
# Plot the t-SNE points with image heatmaps
|
| 2473 |
big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
|
| 2474 |
|
| 2475 |
-
return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, big_pil_image
|
| 2476 |
|
| 2477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2478 |
|
| 2479 |
run_hierarchical_button.click(
|
| 2480 |
run_fps_tsne_hierarchical,
|
| 2481 |
inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
|
| 2482 |
-
outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot,
|
| 2483 |
)
|
| 2484 |
-
gr.Markdown('---')
|
| 2485 |
-
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
| 2486 |
-
gr.Markdown('---')
|
| 2487 |
with gr.Row():
|
| 2488 |
-
from gradio_image_prompter import ImagePrompter
|
| 2489 |
with gr.Column(scale=5, min_width=200) as tsne_select:
|
|
|
|
|
|
|
|
|
|
| 2490 |
tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
|
| 2491 |
# copy plot to tsne_prompt_image on change
|
| 2492 |
# tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
|
| 2493 |
# inputs=[tsne_plot], outputs=[tsne_prompt_image])
|
| 2494 |
with gr.Column(scale=5, min_width=200) as image_select:
|
|
|
|
|
|
|
|
|
|
| 2495 |
image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
|
| 2496 |
image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
|
| 2497 |
def update_image_prompt(image_slider, output_gallery):
|
|
@@ -2505,11 +2513,16 @@ with demo:
|
|
| 2505 |
image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2506 |
output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2507 |
output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2508 |
with gr.Column(scale=5, min_width=200):
|
| 2509 |
gr.Markdown('<h3 style="text-align: center;">Help</h3>')
|
| 2510 |
with gr.Accordion("Instructions", open=True):
|
| 2511 |
gr.Markdown("""
|
| 2512 |
-
1. Click one dot on the
|
| 2513 |
- Only the last clicked dot will be used
|
| 2514 |
- Eraser is at top-right corner
|
| 2515 |
- Use the right-side Radio to switch tree/image
|
|
@@ -2524,32 +2537,34 @@ with demo:
|
|
| 2524 |
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
| 2525 |
""")
|
| 2526 |
with gr.Column(scale=5, min_width=200):
|
| 2527 |
-
prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
|
| 2528 |
granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
|
| 2529 |
num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
|
| 2530 |
def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
|
| 2531 |
# Plot the t-SNE points
|
| 2532 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
|
| 2533 |
-
|
| 2534 |
-
return gr.update(value={'image': pil_image}, interactive=True)
|
| 2535 |
-
return gr.update(value={'image': pil_image, 'points': tsne_prompt_image['points']}, interactive=True)
|
| 2536 |
granularity_slider.change(updaste_tsne_plot_change_granularity,
|
| 2537 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
|
| 2538 |
-
outputs=[
|
| 2539 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
| 2540 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
| 2541 |
-
outputs=[
|
| 2542 |
prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
|
|
|
|
|
|
|
|
|
| 2543 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
| 2544 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 2545 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
| 2546 |
|
| 2547 |
delete_all_output_button = gr.Button("❌ Delete All Output", elem_id="delete_all_output", variant='secondary')
|
| 2548 |
-
|
|
|
|
|
|
|
| 2549 |
image_select.visible = False
|
| 2550 |
-
tsne_select.visible = True
|
| 2551 |
-
prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree"), inputs=prompt_radio, outputs=[tsne_select])
|
| 2552 |
prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
|
|
|
|
| 2553 |
|
| 2554 |
MAX_ROWS = 20
|
| 2555 |
current_output_row = gr.State(0)
|
|
@@ -2634,19 +2649,21 @@ with demo:
|
|
| 2634 |
closest_idx = np.argmax(sim)
|
| 2635 |
return closest_idx, (_x_ratio, _y_ratio)
|
| 2636 |
|
| 2637 |
-
def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
| 2638 |
try:
|
| 2639 |
if prompt_radio == "Tree":
|
| 2640 |
return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
|
| 2641 |
if prompt_radio == "Image":
|
| 2642 |
return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
|
|
|
|
|
|
|
| 2643 |
except:
|
| 2644 |
raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
|
| 2645 |
|
| 2646 |
-
def run_inspection(tsne_prompt, image_prompt, prompt_radio, current_output_row, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
|
| 2647 |
if len(tsne2d_embed) == 0:
|
| 2648 |
raise gr.Error("Please run FPS+Cluster first.")
|
| 2649 |
-
closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
|
| 2650 |
closest_rgb = fps_tsne_rgb[closest_idx]
|
| 2651 |
closest_rgb = (closest_rgb * 255).astype(np.uint8)
|
| 2652 |
|
|
@@ -2728,7 +2745,7 @@ with demo:
|
|
| 2728 |
|
| 2729 |
run_inspection_button.click(
|
| 2730 |
run_inspection,
|
| 2731 |
-
inputs=[tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy],
|
| 2732 |
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
|
| 2733 |
)
|
| 2734 |
|
|
|
|
| 13 |
import zipfile
|
| 14 |
import multiprocessing as mp
|
| 15 |
|
| 16 |
+
from gradio_image_prompter import ImagePrompter
|
| 17 |
+
|
| 18 |
from einops import rearrange
|
| 19 |
from matplotlib import pyplot as plt
|
| 20 |
import matplotlib
|
|
|
|
| 2474 |
# Plot the t-SNE points with image heatmaps
|
| 2475 |
big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
|
| 2476 |
|
| 2477 |
+
return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, gr.update(value={'image': big_pil_image, 'points': []}, interactive=True)
|
| 2478 |
|
| 2479 |
+
gr.Markdown('---')
|
| 2480 |
+
gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
|
| 2481 |
+
gr.Markdown('---')
|
| 2482 |
+
# big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png')
|
| 2483 |
+
tsne_image_plot = ImagePrompter(show_label=True, elem_id="tsne_image_plot", interactive=False, label="spectral-tSNE tree [+ Cluster Heatmap]")
|
| 2484 |
|
| 2485 |
run_hierarchical_button.click(
|
| 2486 |
run_fps_tsne_hierarchical,
|
| 2487 |
inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
|
| 2488 |
+
outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
|
| 2489 |
)
|
|
|
|
|
|
|
|
|
|
| 2490 |
with gr.Row():
|
|
|
|
| 2491 |
with gr.Column(scale=5, min_width=200) as tsne_select:
|
| 2492 |
+
gr.Markdown('---')
|
| 2493 |
+
gr.Markdown('<h3 style="text-align: center;">Please click on the image blow ↓</h3>')
|
| 2494 |
+
gr.Markdown('---')
|
| 2495 |
tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
|
| 2496 |
# copy plot to tsne_prompt_image on change
|
| 2497 |
# tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
|
| 2498 |
# inputs=[tsne_plot], outputs=[tsne_prompt_image])
|
| 2499 |
with gr.Column(scale=5, min_width=200) as image_select:
|
| 2500 |
+
gr.Markdown('---')
|
| 2501 |
+
gr.Markdown('<h3 style="text-align: center;">Please click on the image blow ↓</h3>')
|
| 2502 |
+
gr.Markdown('---')
|
| 2503 |
image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
|
| 2504 |
image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
|
| 2505 |
def update_image_prompt(image_slider, output_gallery):
|
|
|
|
| 2513 |
image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2514 |
output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2515 |
output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
|
| 2516 |
+
with gr.Column(scale=5, min_width=200) as tsne_image_select:
|
| 2517 |
+
gr.Markdown('---')
|
| 2518 |
+
gr.Markdown('<h3 style="text-align: center;">Please click on the image above ↑</h3>')
|
| 2519 |
+
gr.Markdown('---')
|
| 2520 |
+
tsne_non_prompt_image = gr.Image(label="spectral-tSNE tree", elem_id="tsne_non_prompt_image", interactive=False, format='png')
|
| 2521 |
with gr.Column(scale=5, min_width=200):
|
| 2522 |
gr.Markdown('<h3 style="text-align: center;">Help</h3>')
|
| 2523 |
with gr.Accordion("Instructions", open=True):
|
| 2524 |
gr.Markdown("""
|
| 2525 |
+
1. Click one dot on the image.
|
| 2526 |
- Only the last clicked dot will be used
|
| 2527 |
- Eraser is at top-right corner
|
| 2528 |
- Use the right-side Radio to switch tree/image
|
|
|
|
| 2537 |
2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
|
| 2538 |
""")
|
| 2539 |
with gr.Column(scale=5, min_width=200):
|
| 2540 |
+
prompt_radio = gr.Radio(["Tree [+Image]", "Image"], label="Where to click on?", value="Tree [+Image]", elem_id="prompt_radio", show_label=True)
|
| 2541 |
granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
|
| 2542 |
num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
|
| 2543 |
def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
|
| 2544 |
# Plot the t-SNE points
|
| 2545 |
pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
|
| 2546 |
+
return gr.update(value=pil_image, label=f"spectral-tSNE tree [k={granularity}]")
|
|
|
|
|
|
|
| 2547 |
granularity_slider.change(updaste_tsne_plot_change_granularity,
|
| 2548 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
|
| 2549 |
+
outputs=[tsne_non_prompt_image])
|
| 2550 |
tsne_plot.change(updaste_tsne_plot_change_granularity,
|
| 2551 |
inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
|
| 2552 |
+
outputs=[tsne_non_prompt_image])
|
| 2553 |
prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
|
| 2554 |
+
# prompt_radio.change(updaste_tsne_plot_change_granularity,
|
| 2555 |
+
# inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
|
| 2556 |
+
# outputs=[tsne_non_prompt_image])
|
| 2557 |
run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
|
| 2558 |
inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
| 2559 |
# output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
|
| 2560 |
|
| 2561 |
delete_all_output_button = gr.Button("❌ Delete All Output", elem_id="delete_all_output", variant='secondary')
|
| 2562 |
+
|
| 2563 |
+
tsne_image_select.visible = True
|
| 2564 |
+
tsne_select.visible = False
|
| 2565 |
image_select.visible = False
|
|
|
|
|
|
|
| 2566 |
prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
|
| 2567 |
+
prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree [+Image]"), inputs=prompt_radio, outputs=[tsne_image_select])
|
| 2568 |
|
| 2569 |
MAX_ROWS = 20
|
| 2570 |
current_output_row = gr.State(0)
|
|
|
|
| 2649 |
closest_idx = np.argmax(sim)
|
| 2650 |
return closest_idx, (_x_ratio, _y_ratio)
|
| 2651 |
|
| 2652 |
+
def find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
|
| 2653 |
try:
|
| 2654 |
if prompt_radio == "Tree":
|
| 2655 |
return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
|
| 2656 |
if prompt_radio == "Image":
|
| 2657 |
return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
|
| 2658 |
+
if prompt_radio == "Tree [+Image]":
|
| 2659 |
+
return find_closest_fps_point_for_tsne_tree_plot(tsne_image_prompt, tsne2d_embed)
|
| 2660 |
except:
|
| 2661 |
raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
|
| 2662 |
|
| 2663 |
+
def run_inspection(tsne_image_prompt, tsne_prompt, image_prompt, prompt_radio, current_output_row, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
|
| 2664 |
if len(tsne2d_embed) == 0:
|
| 2665 |
raise gr.Error("Please run FPS+Cluster first.")
|
| 2666 |
+
closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
|
| 2667 |
closest_rgb = fps_tsne_rgb[closest_idx]
|
| 2668 |
closest_rgb = (closest_rgb * 255).astype(np.uint8)
|
| 2669 |
|
|
|
|
| 2745 |
|
| 2746 |
run_inspection_button.click(
|
| 2747 |
run_inspection,
|
| 2748 |
+
inputs=[tsne_image_plot, tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy],
|
| 2749 |
outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
|
| 2750 |
)
|
| 2751 |
|